import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class SRMConv2d_simple(nn.Module):
def __init__(self, inc=3, learnable=False):
super(SRMConv2d_simple, self).__init__()
self.truc = nn.Hardtanh(-3, 3)
kernel = self._build_kernel(inc) # (3,3,5,5)
self.kernel = nn.Parameter(data=kernel, requires_grad=learnable)
# self.hor_kernel = self._build_kernel().transpose(0,1,3,2)
def forward(self, x):
'''
x: imgs (Batch, H, W, 3)
'''
out = F.conv2d(x, self.kernel, stride=1, padding=2)
out = self.truc(out)
return out
def _build_kernel(self, inc):
# filter1: KB
filter1 = [[0, 0, 0, 0, 0],
[0, -1, 2, -1, 0],
[0, 2, -4, 2, 0],
[0, -1, 2, -1, 0],
[0, 0, 0, 0, 0]]
# filter2:KV
filter2 = [[-1, 2, -2, 2, -1],
[2, -6, 8, -6, 2],
[-2, 8, -12, 8, -2],
[2, -6, 8, -6, 2],
[-1, 2, -2, 2, -1]]
# filter3:hor 2rd
filter3 = [[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 1, -2, 1, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]
filter1 = np.asarray(filter1, dtype=float) / 4.
filter2 = np.asarray(filter2, dtype=float) / 12.
filter3 = np.asarray(filter3, dtype=float) / 2.
# statck the filters
filters = [[filter1],#, filter1, filter1],
[filter2],#, filter2, filter2],
[filter3]]#, filter3, filter3]] # (3,3,5,5)
filters = np.array(filters)
filters = np.repeat(filters, inc, axis=1)
filters = torch.FloatTensor(filters) # (3,3,5,5)
return filters
class SRMConv2d_Separate(nn.Module):
def __init__(self, inc, outc, learnable=False):
super(SRMConv2d_Separate, self).__init__()
self.inc = inc
self.truc = nn.Hardtanh(-3, 3)
kernel = self._build_kernel(inc) # (3,3,5,5)
self.kernel = nn.Parameter(data=kernel, requires_grad=learnable)
# self.hor_kernel = self._build_kernel().transpose(0,1,3,2)
self.out_conv = nn.Sequential(
nn.Conv2d(3*inc, outc, 1, 1, 0, 1, 1, bias=False),
nn.BatchNorm2d(outc),
nn.ReLU(inplace=True)
)
for ly in self.out_conv.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
def forward(self, x):
'''
x: imgs (Batch, H, W, 3)
'''
out = F.conv2d(x, self.kernel, stride=1, padding=2, groups=self.inc)
out = self.truc(out)
out = self.out_conv(out)
return out
def _build_kernel(self, inc):
# filter1: KB
filter1 = [[0, 0, 0, 0, 0],
[0, -1, 2, -1, 0],
[0, 2, -4, 2, 0],
[0, -1, 2, -1, 0],
[0, 0, 0, 0, 0]]
# filter2:KV
filter2 = [[-1, 2, -2, 2, -1],
[2, -6, 8, -6, 2],
[-2, 8, -12, 8, -2],
[2, -6, 8, -6, 2],
[-1, 2, -2, 2, -1]]
# # filter3:hor 2rd
filter3 = [[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 1, -2, 1, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]
filter1 = np.asarray(filter1, dtype=float) / 4.
filter2 = np.asarray(filter2, dtype=float) / 12.
filter3 = np.asarray(filter3, dtype=float) / 2.
# statck the filters
filters = [[filter1],#, filter1, filter1],
[filter2],#, filter2, filter2],
[filter3]]#, filter3, filter3]] # (3,3,5,5)
filters = np.array(filters)
# filters = np.repeat(filters, inc, axis=1)
filters = np.repeat(filters, inc, axis=0)
filters = torch.FloatTensor(filters) # (3,3,5,5)
# print(filters.size())
return filters
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
"""
Channel Attention and Spaitial Attention from
Woo, S., Park, J., Lee, J.Y., & Kweon, I. CBAM: Convolutional Block Attention Module. ECCV2018.
"""
class ChannelAttention(nn.Module):
def __init__(self, in_planes, ratio=8):
super(ChannelAttention, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.sharedMLP = nn.Sequential(
nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
nn.ReLU(),
nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))
self.sigmoid = nn.Sigmoid()
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.xavier_normal_(m.weight.data, gain=0.02)
def forward(self, x):
avgout = self.sharedMLP(self.avg_pool(x))
maxout = self.sharedMLP(self.max_pool(x))
return self.sigmoid(avgout + maxout)
class SpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
super(SpatialAttention, self).__init__()
assert kernel_size in (3, 7), "kernel size must be 3 or 7"
padding = 3 if kernel_size == 7 else 1
self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
self.sigmoid = nn.Sigmoid()
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.xavier_normal_(m.weight.data, gain=0.02)
def forward(self, x):
avgout = torch.mean(x, dim=1, keepdim=True)
maxout, _ = torch.max(x, dim=1, keepdim=True)
x = torch.cat([avgout, maxout], dim=1)
x = self.conv(x)
return self.sigmoid(x)
"""
The following modules are modified based on https://github.com/heykeetae/Self-Attention-GAN
"""
class Self_Attn(nn.Module):
""" Self attention Layer"""
def __init__(self, in_dim, out_dim=None, add=False, ratio=8):
super(Self_Attn, self).__init__()
self.chanel_in = in_dim
self.add = add
if out_dim is None:
out_dim = in_dim
self.out_dim = out_dim
# self.activation = activation
self.query_conv = nn.Conv2d(
in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
self.key_conv = nn.Conv2d(
in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
self.value_conv = nn.Conv2d(
in_channels=in_dim, out_channels=out_dim, kernel_size=1)
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
"""
inputs :
x : input feature maps( B X C X W X H)
returns :
out : self attention value + input feature
attention: B X N X N (N is Width*Height)
"""
m_batchsize, C, width, height = x.size()
proj_query = self.query_conv(x).view(
m_batchsize, -1, width*height).permute(0, 2, 1) # B X C X(N)
proj_key = self.key_conv(x).view(
m_batchsize, -1, width*height) # B X C x (*W*H)
energy = torch.bmm(proj_query, proj_key) # transpose check
attention = self.softmax(energy) # BX (N) X (N)
proj_value = self.value_conv(x).view(
m_batchsize, -1, width*height) # B X C X N
out = torch.bmm(proj_value, attention.permute(0, 2, 1))
out = out.view(m_batchsize, self.out_dim, width, height)
if self.add:
out = self.gamma*out + x
else:
out = self.gamma*out
return out # , attention
class CrossModalAttention(nn.Module):
""" CMA attention Layer"""
def __init__(self, in_dim, activation=None, ratio=8, cross_value=True):
super(CrossModalAttention, self).__init__()
self.chanel_in = in_dim
self.activation = activation
self.cross_value = cross_value
self.query_conv = nn.Conv2d(
in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
self.key_conv = nn.Conv2d(
in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
self.value_conv = nn.Conv2d(
in_channels=in_dim, out_channels=in_dim, kernel_size=1)
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.xavier_normal_(m.weight.data, gain=0.02)
def forward(self, x, y):
"""
inputs :
x : input feature maps( B X C X W X H)
returns :
out : self attention value + input feature
attention: B X N X N (N is Width*Height)
"""
B, C, H, W = x.size()
proj_query = self.query_conv(x).view(
B, -1, H*W).permute(0, 2, 1) # B , HW, C
proj_key = self.key_conv(y).view(
B, -1, H*W) # B X C x (*W*H)
energy = torch.bmm(proj_query, proj_key) # B, HW, HW
attention = self.softmax(energy) # BX (N) X (N)
if self.cross_value:
proj_value = self.value_conv(y).view(
B, -1, H*W) # B , C , HW
else:
proj_value = self.value_conv(x).view(
B, -1, H*W) # B , C , HW
out = torch.bmm(proj_value, attention.permute(0, 2, 1))
out = out.view(B, C, H, W)
out = self.gamma*out + x
if self.activation is not None:
out = self.activation(out)
return out # , attention
class DualCrossModalAttention(nn.Module):
""" Dual CMA attention Layer"""
def __init__(self, in_dim, activation=None, size=16, ratio=8, ret_att=False):
super(DualCrossModalAttention, self).__init__()
self.chanel_in = in_dim
self.activation = activation
self.ret_att = ret_att
# query conv
self.key_conv1 = nn.Conv2d(
in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
self.key_conv2 = nn.Conv2d(
in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
self.key_conv_share = nn.Conv2d(
in_channels=in_dim//ratio, out_channels=in_dim//ratio, kernel_size=1)
self.linear1 = nn.Linear(size*size, size*size)
self.linear2 = nn.Linear(size*size, size*size)
# separated value conv
self.value_conv1 = nn.Conv2d(
in_channels=in_dim, out_channels=in_dim, kernel_size=1)
self.gamma1 = nn.Parameter(torch.zeros(1))
self.value_conv2 = nn.Conv2d(
in_channels=in_dim, out_channels=in_dim, kernel_size=1)
self.gamma2 = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.xavier_normal_(m.weight.data, gain=0.02)
if isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight.data, gain=0.02)
def forward(self, x, y):
"""
inputs :
x : input feature maps( B X C X W X H)
returns :
out : self attention value + input feature
attention: B X N X N (N is Width*Height)
"""
B, C, H, W = x.size()
def _get_att(a, b):
proj_key1 = self.key_conv_share(self.key_conv1(a)).view(
B, -1, H*W).permute(0, 2, 1) # B, HW, C
proj_key2 = self.key_conv_share(self.key_conv2(b)).view(
B, -1, H*W) # B X C x (*W*H)
energy = torch.bmm(proj_key1, proj_key2) # B, HW, HW
attention1 = self.softmax(self.linear1(energy))
attention2 = self.softmax(self.linear2(
energy.permute(0, 2, 1))) # BX (N) X (N)
return attention1, attention2
att_y_on_x, att_x_on_y = _get_att(x, y)
proj_value_y_on_x = self.value_conv2(y).view(
B, -1, H*W) # B, C, HW
out_y_on_x = torch.bmm(proj_value_y_on_x, att_y_on_x.permute(0, 2, 1))
out_y_on_x = out_y_on_x.view(B, C, H, W)
out_x = self.gamma1*out_y_on_x + x
proj_value_x_on_y = self.value_conv1(x).view(
B, -1, H*W) # B , C , HW
out_x_on_y = torch.bmm(proj_value_x_on_y, att_x_on_y.permute(0, 2, 1))
out_x_on_y = out_x_on_y.view(B, C, H, W)
out_y = self.gamma2*out_x_on_y + y
if self.ret_att:
return out_x, out_y, att_y_on_x, att_x_on_y
return out_x, out_y # , attention
if __name__ == "__main__":
x = torch.rand(10, 768, 16, 16)
y = torch.rand(10, 768, 16, 16)
dcma = DualCrossModalAttention(768, ret_att=True)
out_x, out_y, att_y_on_x, att_x_on_y = dcma(x, y)
print(out_y.size())
print(att_x_on_y.size())
torch.Size([10, 768, 16, 16]) torch.Size([10, 256, 256])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
class AngleSimpleLinear(nn.Module):
"""Computes cos of angles between input vectors and weights vectors"""
def __init__(self, in_features, out_features):
super(AngleSimpleLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.Tensor(in_features, out_features))
self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
def forward(self, x):
cos_theta = F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0))
return cos_theta.clamp(-1, 1)
def focal_loss(input_values, gamma):
"""Computes the focal loss"""
p = torch.exp(-input_values)
loss = (1 - p) ** gamma * input_values
return loss.mean()
class AMSoftmaxLoss(nn.Module):
"""Computes the AM-Softmax loss with cos or arc margin"""
margin_types = ['cos', 'arc']
def __init__(self, margin_type='cos', gamma=0., m=0.5, s=30, t=1.):
super(AMSoftmaxLoss, self).__init__()
assert margin_type in AMSoftmaxLoss.margin_types
self.margin_type = margin_type
assert gamma >= 0
self.gamma = gamma
assert m > 0
self.m = m
assert s > 0
self.s = s
self.cos_m = math.cos(m)
self.sin_m = math.sin(m)
self.th = math.cos(math.pi - m)
assert t >= 1
self.t = t
def forward(self, cos_theta, target):
if self.margin_type == 'cos':
phi_theta = cos_theta - self.m
else:
sine = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
phi_theta = cos_theta * self.cos_m - sine * self.sin_m #cos(theta+m)
phi_theta = torch.where(cos_theta > self.th, phi_theta, cos_theta - self.sin_m * self.m)
index = torch.zeros_like(cos_theta, dtype=torch.uint8)
index.scatter_(1, target.data.view(-1, 1), 1)
output = torch.where(index, phi_theta, cos_theta)
if self.gamma == 0 and self.t == 1.:
return F.cross_entropy(self.s*output, target)
if self.t > 1:
h_theta = self.t - 1 + self.t*cos_theta
support_vecs_mask = (1 - index) * \
torch.lt(torch.masked_select(phi_theta, index).view(-1, 1).repeat(1, h_theta.shape[1]) - cos_theta, 0)
output = torch.where(support_vecs_mask, h_theta, output)
return F.cross_entropy(self.s*output, target)
return focal_loss(F.cross_entropy(self.s*output, target, reduction='none'), self.gamma)
"""
Code from https://github.com/ondyari/FaceForensics
Author: Andreas Rössler
"""
import os
import argparse
import torch
# import pretrainedmodels
import torch.nn as nn
import torch.nn.functional as F
# from lib.nets.xception import xception
import math
import torchvision
# import math
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.nn import init
pretrained_settings = {
'xception': {
'imagenet': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth',
'input_space': 'RGB',
'input_size': [3, 299, 299],
'input_range': [0, 1],
'mean': [0.5, 0.5, 0.5],
'std': [0.5, 0.5, 0.5],
'num_classes': 1000,
'scale': 0.8975 # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
}
}
}
PRETAINED_WEIGHT_PATH = '/kaggle/input/xceptionb5690688pth/xception-b5690688.pth'
class SeparableConv2d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
super(SeparableConv2d, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size,
stride, padding, dilation, groups=in_channels, bias=bias)
self.pointwise = nn.Conv2d(
in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)
def forward(self, x):
x = self.conv1(x)
x = self.pointwise(x)
return x
class Block(nn.Module):
def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
super(Block, self).__init__()
if out_filters != in_filters or strides != 1:
self.skip = nn.Conv2d(in_filters, out_filters,
1, stride=strides, bias=False)
self.skipbn = nn.BatchNorm2d(out_filters)
else:
self.skip = None
self.relu = nn.ReLU(inplace=True)
rep = []
filters = in_filters
if grow_first:
rep.append(self.relu)
rep.append(SeparableConv2d(in_filters, out_filters,
3, stride=1, padding=1, bias=False))
rep.append(nn.BatchNorm2d(out_filters))
filters = out_filters
for i in range(reps-1):
rep.append(self.relu)
rep.append(SeparableConv2d(filters, filters,
3, stride=1, padding=1, bias=False))
rep.append(nn.BatchNorm2d(filters))
if not grow_first:
rep.append(self.relu)
rep.append(SeparableConv2d(in_filters, out_filters,
3, stride=1, padding=1, bias=False))
rep.append(nn.BatchNorm2d(out_filters))
if not start_with_relu:
rep = rep[1:]
else:
rep[0] = nn.ReLU(inplace=False)
if strides != 1:
rep.append(nn.MaxPool2d(3, strides, 1))
self.rep = nn.Sequential(*rep)
def forward(self, inp):
x = self.rep(inp)
if self.skip is not None:
skip = self.skip(inp)
skip = self.skipbn(skip)
else:
skip = inp
x += skip
return x
def add_gaussian_noise(ins, mean=0, stddev=0.2):
noise = ins.data.new(ins.size()).normal_(mean, stddev)
return ins + noise
class Xception(nn.Module):
"""
Xception optimized for the ImageNet dataset, as specified in
https://arxiv.org/pdf/1610.02357.pdf
"""
def __init__(self, num_classes=1000, inc=3):
""" Constructor
Args:
num_classes: number of classes
"""
super(Xception, self).__init__()
self.num_classes = num_classes
# Entry flow
self.conv1 = nn.Conv2d(inc, 32, 3, 2, 0, bias=False)
self.bn1 = nn.BatchNorm2d(32)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
self.bn2 = nn.BatchNorm2d(64)
# do relu here
self.block1 = Block(
64, 128, 2, 2, start_with_relu=False, grow_first=True)
self.block2 = Block(
128, 256, 2, 2, start_with_relu=True, grow_first=True)
self.block3 = Block(
256, 728, 2, 2, start_with_relu=True, grow_first=True)
# middle flow
self.block4 = Block(
728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block5 = Block(
728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block6 = Block(
728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block7 = Block(
728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block8 = Block(
728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block9 = Block(
728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block10 = Block(
728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block11 = Block(
728, 728, 3, 1, start_with_relu=True, grow_first=True)
# Exit flow
self.block12 = Block(
728, 1024, 2, 2, start_with_relu=True, grow_first=False)
self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
self.bn3 = nn.BatchNorm2d(1536)
# do relu here
self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
self.bn4 = nn.BatchNorm2d(2048)
self.fc = nn.Linear(2048, num_classes)
# #------- init weights --------
# for m in self.modules():
# if isinstance(m, nn.Conv2d):
# n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
# m.weight.data.normal_(0, math.sqrt(2. / n))
# elif isinstance(m, nn.BatchNorm2d):
# m.weight.data.fill_(1)
# m.bias.data.zero_()
# #-----------------------------
def fea_part1_0(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
return x
def fea_part1_1(self, x):
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
return x
def fea_part1(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
return x
def fea_part2(self, x):
x = self.block1(x)
x = self.block2(x)
x = self.block3(x)
return x
def fea_part3(self, x):
x = self.block4(x)
x = self.block5(x)
x = self.block6(x)
x = self.block7(x)
return x
def fea_part4(self, x):
x = self.block8(x)
x = self.block9(x)
x = self.block10(x)
x = self.block11(x)
return x
def fea_part5(self, x):
x = self.block12(x)
x = self.conv3(x)
x = self.bn3(x)
x = self.relu(x)
x = self.conv4(x)
x = self.bn4(x)
return x
def features(self, input):
x = self.fea_part1(input)
x = self.fea_part2(x)
x = self.fea_part3(x)
x = self.fea_part4(x)
x = self.fea_part5(x)
return x
def classifier(self, features):
x = self.relu(features)
x = F.adaptive_avg_pool2d(x, (1, 1))
x = x.view(x.size(0), -1)
out = self.last_linear(x)
return out, x
def forward(self, input):
x = self.features(input)
out, x = self.classifier(x)
return out, x
def xception(num_classes=1000, pretrained='imagenet', inc=3):
model = Xception(num_classes=num_classes, inc=inc)
if pretrained:
settings = pretrained_settings['xception'][pretrained]
assert num_classes == settings['num_classes'], \
"num_classes should be {}, but is {}".format(
settings['num_classes'], num_classes)
model = Xception(num_classes=num_classes)
model.load_state_dict(model_zoo.load_url(settings['url']))
model.input_space = settings['input_space']
model.input_size = settings['input_size']
model.input_range = settings['input_range']
model.mean = settings['mean']
model.std = settings['std']
# TODO: ugly
model.last_linear = model.fc
del model.fc
return model
class TransferModel(nn.Module):
"""
Simple transfer learning model that takes an imagenet pretrained model with
a fc layer as base model and retrains a new fc layer for num_out_classes
"""
def __init__(self, modelchoice, num_out_classes=2, dropout=0.0,
weight_norm=False, return_fea=False, inc=3):
super(TransferModel, self).__init__()
self.modelchoice = modelchoice
self.return_fea = return_fea
if modelchoice == 'xception':
def return_pytorch04_xception(pretrained=True):
# Raises warning "src not broadcastable to dst" but thats fine
model = xception(pretrained=False)
if pretrained:
# Load model in torch 0.4+
model.fc = model.last_linear
del model.last_linear
state_dict = torch.load(
PRETAINED_WEIGHT_PATH)
for name, weights in state_dict.items():
if 'pointwise' in name:
state_dict[name] = weights.unsqueeze(
-1).unsqueeze(-1)
model.load_state_dict(state_dict)
model.last_linear = model.fc
del model.fc
return model
self.model = return_pytorch04_xception()
# Replace fc
num_ftrs = self.model.last_linear.in_features
if not dropout:
if weight_norm:
print('Using Weight_Norm')
self.model.last_linear = nn.utils.weight_norm(
nn.Linear(num_ftrs, num_out_classes), name='weight')
self.model.last_linear = nn.Linear(num_ftrs, num_out_classes)
else:
print('Using dropout', dropout)
if weight_norm:
print('Using Weight_Norm')
self.model.last_linear = nn.Sequential(
nn.Dropout(p=dropout),
nn.utils.weight_norm(
nn.Linear(num_ftrs, num_out_classes), name='weight')
)
self.model.last_linear = nn.Sequential(
nn.Dropout(p=dropout),
nn.Linear(num_ftrs, num_out_classes)
)
if inc != 3:
self.model.conv1 = nn.Conv2d(inc, 32, 3, 2, 0, bias=False)
nn.init.xavier_normal(self.model.conv1.weight.data, gain=0.02)
elif modelchoice == 'resnet50' or modelchoice == 'resnet18':
if modelchoice == 'resnet50':
self.model = torchvision.models.resnet50(pretrained=True)
if modelchoice == 'resnet18':
self.model = torchvision.models.resnet18(pretrained=True)
# Replace fc
num_ftrs = self.model.fc.in_features
if not dropout:
self.model.fc = nn.Linear(num_ftrs, num_out_classes)
else:
self.model.fc = nn.Sequential(
nn.Dropout(p=dropout),
nn.Linear(num_ftrs, num_out_classes)
)
else:
raise Exception('Choose valid model, e.g. resnet50')
def set_trainable_up_to(self, boolean=False, layername="Conv2d_4a_3x3"):
"""
Freezes all layers below a specific layer and sets the following layers
to true if boolean else only the fully connected final layer
:param boolean:
:param layername: depends on lib, for inception e.g. Conv2d_4a_3x3
:return:
"""
# Stage-1: freeze all the layers
if layername is None:
for i, param in self.model.named_parameters():
param.requires_grad = True
return
else:
for i, param in self.model.named_parameters():
param.requires_grad = False
if boolean:
# Make all layers following the layername layer trainable
ct = []
found = False
for name, child in self.model.named_children():
if layername in ct:
found = True
for params in child.parameters():
params.requires_grad = True
ct.append(name)
if not found:
raise NotImplementedError('Layer not found, cant finetune!'.format(
layername))
else:
if self.modelchoice == 'xception':
# Make fc trainable
for param in self.model.last_linear.parameters():
param.requires_grad = True
else:
# Make fc trainable
for param in self.model.fc.parameters():
param.requires_grad = True
def forward(self, x):
out, x = self.model(x)
if self.return_fea:
return out, x
else:
return out
def features(self, x):
x = self.model.features(x)
return x
def classifier(self, x):
out, x = self.model.classifier(x)
return out, x
def model_selection(modelname, num_out_classes,
dropout=None):
"""
:param modelname:
:return: model, image size, pretraining<yes/no>, input_list
"""
if modelname == 'xception':
return TransferModel(modelchoice='xception',
num_out_classes=num_out_classes), 299, \
True, ['image'], None
elif modelname == 'resnet18':
return TransferModel(modelchoice='resnet18', dropout=dropout,
num_out_classes=num_out_classes), \
224, True, ['image'], None
else:
raise NotImplementedError(modelname)
if __name__ == '__main__':
model = TransferModel('xception', dropout=0.5)
print(model)
# model = model.cuda()
# from torchsummary import summary
# input_s = (3, image_size, image_size)
# print(summary(model, input_s))
dummy = torch.rand(10, 3, 256, 256)
out = model(dummy)
print(out.size())
x = model.features(dummy)
out, x = model.classifier(x)
print(out.size())
print(x.size())
Using dropout 0.5
TransferModel(
(model): Xception(
(conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(block1): Block(
(skip): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(skipbn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): SeparableConv2d(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
(pointwise): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): SeparableConv2d(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
(pointwise): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)
)
(block2): Block(
(skip): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(skipbn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
(pointwise): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)
(pointwise): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)
)
(block3): Block(
(skip): Conv2d(256, 728, kernel_size=(1, 1), stride=(2, 2), bias=False)
(skipbn): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)
(pointwise): Conv2d(256, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)
)
(block4): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block5): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block6): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block7): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block8): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block9): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block10): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block11): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block12): Block(
(skip): Conv2d(728, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
(skipbn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)
)
(conv3): SeparableConv2d(
(conv1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)
(pointwise): Conv2d(1024, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn3): BatchNorm2d(1536, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv4): SeparableConv2d(
(conv1): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)
(pointwise): Conv2d(1536, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn4): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(last_linear): Sequential(
(0): Dropout(p=0.5, inplace=False)
(1): Linear(in_features=2048, out_features=2, bias=True)
)
)
)
torch.Size([10, 2])
torch.Size([10, 2])
torch.Size([10, 2048])
import torch
import torch.nn as nn
import torch.nn.functional as F
# from components.attention import ChannelAttention, SpatialAttention, DualCrossModalAttention
# from components.srm_conv import SRMConv2d_simple, SRMConv2d_Separate
# from networks.xception import TransferModel
class SRMPixelAttention(nn.Module):
def __init__(self, in_channels):
super(SRMPixelAttention, self).__init__()
self.srm = SRMConv2d_simple()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, 32, 3, 2, 0, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(32, 64, 3, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
)
self.pa = SpatialAttention()
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, a=1)
if not m.bias is None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
x_srm = self.srm(x)
fea = self.conv(x_srm)
att_map = self.pa(fea)
return att_map
class FeatureFusionModule(nn.Module):
def __init__(self, in_chan=2048*2, out_chan=2048, *args, **kwargs):
super(FeatureFusionModule, self).__init__()
self.convblk = nn.Sequential(
nn.Conv2d(in_chan, out_chan, 1, 1, 0, bias=False),
nn.BatchNorm2d(out_chan),
nn.ReLU()
)
self.ca = ChannelAttention(out_chan, ratio=16)
self.init_weight()
def forward(self, x, y):
fuse_fea = self.convblk(torch.cat((x, y), dim=1))
fuse_fea = fuse_fea + fuse_fea * self.ca(fuse_fea)
return fuse_fea
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None:
nn.init.constant_(ly.bias, 0)
class Two_Stream_Net(nn.Module):
def __init__(self):
super().__init__()
self.xception_rgb = TransferModel(
'xception', dropout=0.5, inc=3, return_fea=True)
self.xception_srm = TransferModel(
'xception', dropout=0.5, inc=3, return_fea=True)
self.srm_conv0 = SRMConv2d_simple(inc=3)
self.srm_conv1 = SRMConv2d_Separate(32, 32)
self.srm_conv2 = SRMConv2d_Separate(64, 64)
self.relu = nn.ReLU(inplace=True)
self.att_map = None
self.srm_sa = SRMPixelAttention(3)
self.srm_sa_post = nn.Sequential(
nn.BatchNorm2d(64),
nn.ReLU(inplace=True)
)
self.dual_cma0 = DualCrossModalAttention(in_dim=728, ret_att=False)
self.dual_cma1 = DualCrossModalAttention(in_dim=728, ret_att=False)
self.fusion = FeatureFusionModule()
self.att_dic = {}
def features(self, x):
srm = self.srm_conv0(x)
x = self.xception_rgb.model.fea_part1_0(x)
y = self.xception_srm.model.fea_part1_0(srm) \
+ self.srm_conv1(x)
y = self.relu(y)
x = self.xception_rgb.model.fea_part1_1(x)
y = self.xception_srm.model.fea_part1_1(y) \
+ self.srm_conv2(x)
y = self.relu(y)
# srm guided spatial attention
self.att_map = self.srm_sa(srm)
x = x * self.att_map + x
x = self.srm_sa_post(x)
x = self.xception_rgb.model.fea_part2(x)
y = self.xception_srm.model.fea_part2(y)
x, y = self.dual_cma0(x, y)
x = self.xception_rgb.model.fea_part3(x)
y = self.xception_srm.model.fea_part3(y)
x, y = self.dual_cma1(x, y)
x = self.xception_rgb.model.fea_part4(x)
y = self.xception_srm.model.fea_part4(y)
x = self.xception_rgb.model.fea_part5(x)
y = self.xception_srm.model.fea_part5(y)
fea = self.fusion(x, y)
return fea
def classifier(self, fea):
out, fea = self.xception_rgb.classifier(fea)
return out, fea
def forward(self, x):
'''
x: original rgb
'''
out, fea = self.classifier(self.features(x))
# return out, fea, self.att_map
return out
if __name__ == '__main__':
model = Two_Stream_Net()
dummy = torch.rand((1,3,256,256))
out = model(dummy)
print(model)
Using dropout 0.5
Using dropout 0.5
Two_Stream_Net(
(xception_rgb): TransferModel(
(model): Xception(
(conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(block1): Block(
(skip): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(skipbn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): SeparableConv2d(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
(pointwise): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): SeparableConv2d(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
(pointwise): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)
)
(block2): Block(
(skip): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(skipbn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
(pointwise): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)
(pointwise): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)
)
(block3): Block(
(skip): Conv2d(256, 728, kernel_size=(1, 1), stride=(2, 2), bias=False)
(skipbn): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)
(pointwise): Conv2d(256, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)
)
(block4): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block5): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block6): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block7): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block8): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block9): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block10): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block11): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block12): Block(
(skip): Conv2d(728, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
(skipbn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)
)
(conv3): SeparableConv2d(
(conv1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)
(pointwise): Conv2d(1024, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn3): BatchNorm2d(1536, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv4): SeparableConv2d(
(conv1): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)
(pointwise): Conv2d(1536, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn4): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(last_linear): Sequential(
(0): Dropout(p=0.5, inplace=False)
(1): Linear(in_features=2048, out_features=2, bias=True)
)
)
)
(xception_srm): TransferModel(
(model): Xception(
(conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(block1): Block(
(skip): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(skipbn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): SeparableConv2d(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
(pointwise): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): SeparableConv2d(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
(pointwise): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)
)
(block2): Block(
(skip): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(skipbn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
(pointwise): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)
(pointwise): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)
)
(block3): Block(
(skip): Conv2d(256, 728, kernel_size=(1, 1), stride=(2, 2), bias=False)
(skipbn): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)
(pointwise): Conv2d(256, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)
)
(block4): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block5): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block6): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block7): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block8): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block9): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block10): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block11): Block(
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block12): Block(
(skip): Conv2d(728, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
(skipbn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(rep): Sequential(
(0): ReLU()
(1): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): SeparableConv2d(
(conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
(pointwise): Conv2d(728, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)
)
(conv3): SeparableConv2d(
(conv1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)
(pointwise): Conv2d(1024, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn3): BatchNorm2d(1536, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv4): SeparableConv2d(
(conv1): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)
(pointwise): Conv2d(1536, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn4): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(last_linear): Sequential(
(0): Dropout(p=0.5, inplace=False)
(1): Linear(in_features=2048, out_features=2, bias=True)
)
)
)
(srm_conv0): SRMConv2d_simple(
(truc): Hardtanh(min_val=-3, max_val=3)
)
(srm_conv1): SRMConv2d_Separate(
(truc): Hardtanh(min_val=-3, max_val=3)
(out_conv): Sequential(
(0): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
(srm_conv2): SRMConv2d_Separate(
(truc): Hardtanh(min_val=-3, max_val=3)
(out_conv): Sequential(
(0): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
(relu): ReLU(inplace=True)
(srm_sa): SRMPixelAttention(
(srm): SRMConv2d_simple(
(truc): Hardtanh(min_val=-3, max_val=3)
)
(conv): Sequential(
(0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
(pa): SpatialAttention(
(conv): Conv2d(2, 1, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
(sigmoid): Sigmoid()
)
)
(srm_sa_post): Sequential(
(0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(1): ReLU(inplace=True)
)
(dual_cma0): DualCrossModalAttention(
(key_conv1): Conv2d(728, 91, kernel_size=(1, 1), stride=(1, 1))
(key_conv2): Conv2d(728, 91, kernel_size=(1, 1), stride=(1, 1))
(key_conv_share): Conv2d(91, 91, kernel_size=(1, 1), stride=(1, 1))
(linear1): Linear(in_features=256, out_features=256, bias=True)
(linear2): Linear(in_features=256, out_features=256, bias=True)
(value_conv1): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1))
(value_conv2): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1))
(softmax): Softmax(dim=-1)
)
(dual_cma1): DualCrossModalAttention(
(key_conv1): Conv2d(728, 91, kernel_size=(1, 1), stride=(1, 1))
(key_conv2): Conv2d(728, 91, kernel_size=(1, 1), stride=(1, 1))
(key_conv_share): Conv2d(91, 91, kernel_size=(1, 1), stride=(1, 1))
(linear1): Linear(in_features=256, out_features=256, bias=True)
(linear2): Linear(in_features=256, out_features=256, bias=True)
(value_conv1): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1))
(value_conv2): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1))
(softmax): Softmax(dim=-1)
)
(fusion): FeatureFusionModule(
(convblk): Sequential(
(0): Conv2d(4096, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(ca): ChannelAttention(
(avg_pool): AdaptiveAvgPool2d(output_size=1)
(max_pool): AdaptiveMaxPool2d(output_size=1)
(sharedMLP): Sequential(
(0): Conv2d(2048, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): ReLU()
(2): Conv2d(128, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(sigmoid): Sigmoid()
)
)
)
import numpy as np # linear algebra
import pandas as pd
from glob import glob
# from retinaface import RetinaFace
import torch
from torch import optim
import torchvision
import timm
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
import seaborn as sns
from PIL import Image
import random
import os
from torchvision.transforms import v2
from torch.utils.data import Dataset , DataLoader
import cv2
import matplotlib.pyplot as plt
import albumentations as A
from albumentations import (
Compose, OneOf, Normalize, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip,
RandomBrightnessContrast, Rotate, ShiftScaleRotate, Transpose
)
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import KFold
import torch.nn as nn
from contextlib import contextmanager
from torch.optim import Adam, SGD
from functools import partial
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
import time
from sklearn.metrics import roc_auc_score
import math
from catalyst.data import BalanceClassSampler
txt_to_csv = False
device = 'cuda' if torch.cuda.is_available() else 'cpu'
DIR_PATH = "/kaggle/input/deepfake/phase1"
TRAIN_DIR = "/kaggle/input/deepfake/phase1/trainset"
TEST_DIR = "/kaggle/input/deepfake/phase1/valset"
OUTPUT_DIR = "/kaggle/working/"
class CFG :
seed = 42
n_fold = 5
target_col = 'target'
train=True
inference=False
pseudo_labeling = True
num_classes = 2 #binary class
trn_fold=[0, 1]
debug=False
apex=False
print_freq=20 #every how many batch the scores get showed
num_workers=4
# model_name="eva02_large_patch14_448.mim_m38m_ft_in22k_in1k"
# model_name= "efficientnet_b3"
size=256
scheduler='CosineAnnealingWarmRestarts'
epochs=2
lr=1e-4
min_lr=1e-6
T_0=10 # CosineAnnealingWarmRestarts
batch_size=20
weight_decay=1e-6
gradient_accumulation_steps=1
max_grad_norm=1000
train = pd.read_csv(f"{DIR_PATH +'/trainset_label.txt'}")
test = pd.read_csv(f"{DIR_PATH +'/valset_label.txt'}")
if CFG.pseudo_labeling :
ps = pd.read_csv('/kaggle/input/pseudolabling/b4_nTTA.csv')
ps.rename(columns = {"label" : "target"} , inplace = True)
to_add = ps[(ps['target']>0.9) | (ps['target']<0.1)]
# print(to_add.shape)
to_add["target"] = [1 if i>0.9 else 0 for i in to_add['target']]
print(to_add["target"].value_counts())
shape_before = train.shape
train = pd.concat([train , to_add] , axis=0)
shape_after = train.shape
print(f"The shape of the train set have moved from {shape_before} => {shape_after}")
train.reset_index(drop = True , inplace =True , )
target 1 87148 0 57023 Name: count, dtype: int64 The shape of the train set have moved from (524429, 2) => (668600, 2)
from sklearn.metrics import log_loss
def get_score(y_true, y_pred):
num_classes = 2
total_log_loss = 0.0
y_true = np.array([[0, 1] if i == 1 else [1, 0] for i in y_true])
# print(y_true)
# print(y_pred)
for class_idx in range(num_classes):
class_true = y_true[:,class_idx]
class_pred = y_pred[:, class_idx]
class_log_loss = log_loss(class_true, class_pred)
total_log_loss += class_log_loss
return total_log_loss
# mean_log_loss = total_log_loss / num_classes
# return mean_log_loss
# def get_score(y_true, y_pred):
# # Ensure y_true and y_pred are 1D arrays
# y_true = y_true.flatten()
# y_pred = y_pred.flatten()
# # Calculate the log loss directly
# total_log_loss = log_loss(y_true, y_pred)
@contextmanager
def timer(name):
t0 = time.time()
LOGGER.info(f'[{name}] start')
yield
LOGGER.info(f'[{name}] done in {time.time() - t0:.0f} s.')
def init_logger(log_file=OUTPUT_DIR+'train.log'):
from logging import getLogger, INFO, FileHandler, Formatter, StreamHandler
logger = getLogger(__name__)
logger.setLevel(INFO)
handler1 = StreamHandler()
handler1.setFormatter(Formatter("%(message)s"))
handler2 = FileHandler(filename=log_file)
handler2.setFormatter(Formatter("%(message)s"))
logger.addHandler(handler1)
logger.addHandler(handler2)
return logger
LOGGER = init_logger()
def seed_torch(seed=42):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
seed_torch(seed=CFG.seed)
if CFG.debug:
CFG.epochs = 1
train = train.sample(n=10000, random_state=CFG.seed).reset_index(drop=True)
test = test.sample(n=1000, random_state=CFG.seed).reset_index(drop=True)
files = glob(DIR_PATH+"/valset/*")
def len_txt(txt_file_path):
with open(txt_file_path) as f:
line_count = 0
for line in f:
line_count += 1
return line_count
print(f"The train file contains {len_txt(DIR_PATH +'/trainset_label.txt')} elements")
print(f"The test file contains {len_txt(DIR_PATH +'/valset_label.txt')} elements")
The train file contains 524430 elements The test file contains 147364 elements
# tkhalwidh
if txt_to_csv :
with open(DIR_PATH+"/trainset_label.txt") as f :
counter = 0
for line in tqdm(f , desc = "Collecting train set") :
if counter >= 1 :
l = line.strip().split(",")
new_row = {"img_name": l[0] , "target": l[1]}
train.loc[len(train)] = new_row
counter +=1
with open(DIR_PATH+"/valset_label.txt") as f :
counter = 0
for line in tqdm(f , desc = "Collecting test set") :
if counter >= 1 :
l = line.strip().split(",")
new_row = {"img_name": l[0] , "target": l[1]}
test.loc[len(test)] = new_row
counter +=1
sns.countplot(data = train , x = train["target"])
<Axes: xlabel='target', ylabel='count'>
class TrainDataset(Dataset) :
def __init__(self , df , transform = None) :
self.df = df
self.transform = transform
self.file_names = df["img_name"].values
self.labels = df["target"].values
def __len__(self) :
return len(self.df)
def __getitem__(self, idx):
file_name = self.file_names[idx]
# Check if the file is in the TRAIN_DIR or TEST_DIR
file_path_train = f'{TRAIN_DIR}/{file_name}'
file_path_test = f'{TEST_DIR}/{file_name}'
if os.path.exists(file_path_train):
file_path = file_path_train
elif os.path.exists(file_path_test):
file_path = file_path_test
else:
raise FileNotFoundError(f'File {file_name} not found in either TRAIN_DIR or TEST_DIR')
image = cv2.imread(file_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if self.transform:
augmented = self.transform(image=image)
image = augmented['image']
label = torch.tensor(self.labels[idx]).long()
return image, label
def get_labels(self):
return list(self.labels)
class TestDataset(Dataset) :
def __init__(self , df , transform = None) :
self.df = df
self.transform = transform
self.file_names = df["img_name"].values
def __len__(self) :
return len(self.df)
def __getitem__(self , idx) :
file_name = self.file_names[idx]
file_path = f'{TEST_DIR}/{file_name}'
image = cv2.imread(file_path)
image = cv2.cvtColor(image , cv2.COLOR_BGR2RGB)
if self.transform :
augmented = self.transform(image=image)
image = augmented['image']
return image
train_dataset = TrainDataset(train)
fig, axes = plt.subplots(2, 4, figsize=(10, 7))
for i in range(2):
for j in range(4):
index = i * 3 + j
if index < len(train_dataset):
image, label = train_dataset[index]
axes[i, j].imshow(image)
if label.numpy() == 1:
axes[i, j].set_title("Fake", color="r")
else:
axes[i, j].set_title("Real", color="g")
axes[i, j].axis('off')
plt.tight_layout()
plt.show()
from albumentations import Compose, RandomBrightnessContrast, RandomCrop, \
HorizontalFlip, FancyPCA, HueSaturationValue, OneOf, ToGray, ISONoise, MultiplicativeNoise, CoarseDropout, MedianBlur, Blur, GlassBlur, MotionBlur, \
ShiftScaleRotate, ImageCompression, PadIfNeeded, GaussNoise, GaussianBlur, ToSepia, RandomShadow, RandomGamma, Rotate, Resize
from albumentations import RandomBrightnessContrast
from PIL import Image
# from transforms.albu import IsotropicResize, FFT, SR, DCT, CustomRandomCrop
import cv2
import numpy as np
import os
import imageio
import random
import cv2
import numpy as np
import torch
from albumentations import DualTransform, ImageOnlyTransform
from albumentations.augmentations.crops.transforms import Crop
from skimage.color import rgb2hsv, rgb2gray, rgb2yuv
from skimage import color, exposure, transform
from skimage.exposure import equalize_hist
from albumentations import RandomCrop
from scipy.fftpack import dct, idct
def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC):
h, w = img.shape[:2]
if max(w, h) == size:
return img
if w > h:
scale = size / w
h = h * scale
w = size
else:
scale = size / h
w = w * scale
h = size
interpolation = interpolation_up if scale > 1 else interpolation_down
img = img.astype('uint8')
resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation)
return resized
class IsotropicResize(DualTransform):
def __init__(self, max_side, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC,
always_apply=False, p=1):
super(IsotropicResize, self).__init__(always_apply, p)
self.max_side = max_side
self.interpolation_down = interpolation_down
self.interpolation_up = interpolation_up
def apply(self, img, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC, **params):
return isotropically_resize_image(img, size=self.max_side, interpolation_down=interpolation_down,
interpolation_up=interpolation_up)
def apply_to_mask(self, img, **params):
return self.apply(img, interpolation_down=cv2.INTER_NEAREST, interpolation_up=cv2.INTER_NEAREST, **params)
def get_transform_init_args_names(self):
return ("max_side", "interpolation_down", "interpolation_up")
class Resize4xAndBack(ImageOnlyTransform):
def __init__(self, always_apply=False, p=0.5):
super(Resize4xAndBack, self).__init__(always_apply, p)
def apply(self, img, **params):
h, w = img.shape[:2]
scale = random.choice([2, 4])
img = cv2.resize(img, (w // scale, h // scale), interpolation=cv2.INTER_AREA)
img = cv2.resize(img, (w, h),
interpolation=random.choice([cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_NEAREST]))
return img
class RandomSizedCropNonEmptyMaskIfExists(DualTransform):
def __init__(self, min_max_height, w2h_ratio=[0.7, 1.3], always_apply=False, p=0.5):
super(RandomSizedCropNonEmptyMaskIfExists, self).__init__(always_apply, p)
self.min_max_height = min_max_height
self.w2h_ratio = w2h_ratio
def apply(self, img, x_min=0, x_max=0, y_min=0, y_max=0, **params):
cropped = crop(img, x_min, y_min, x_max, y_max)
return cropped
@property
def targets_as_params(self):
return ["mask"]
def get_params_dependent_on_targets(self, params):
mask = params["mask"]
mask_height, mask_width = mask.shape[:2]
crop_height = int(mask_height * random.uniform(self.min_max_height[0], self.min_max_height[1]))
w2h_ratio = random.uniform(*self.w2h_ratio)
crop_width = min(int(crop_height * w2h_ratio), mask_width - 1)
if mask.sum() == 0:
x_min = random.randint(0, mask_width - crop_width + 1)
y_min = random.randint(0, mask_height - crop_height + 1)
else:
mask = mask.sum(axis=-1) if mask.ndim == 3 else mask
non_zero_yx = np.argwhere(mask)
y, x = random.choice(non_zero_yx)
x_min = x - random.randint(0, crop_width - 1)
y_min = y - random.randint(0, crop_height - 1)
x_min = np.clip(x_min, 0, mask_width - crop_width)
y_min = np.clip(y_min, 0, mask_height - crop_height)
x_max = x_min + crop_height
y_max = y_min + crop_width
y_max = min(mask_height, y_max)
x_max = min(mask_width, x_max)
return {"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max}
def get_transform_init_args_names(self):
return "min_max_height", "height", "width", "w2h_ratio"
class CustomRandomCrop(DualTransform):
def __init__(self, size, p=0.5) -> None:
super(CustomRandomCrop, self).__init__(p=p)
self.size = size
self.prob = p
def apply(self, img, copy=True, **params):
if img.shape[0] < self.size or img.shape[1] < self.size:
transform = IsotropicResize(max_side=self.size, interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR)
else:
transform = RandomCrop(self.size, self.size)
return np.asarray(transform(image=img)["image"])
class FFT(DualTransform):
def __init__(self, mode, p=0.5) -> None:
super(FFT, self).__init__(p=p)
self.prob = p
self.mode = mode
def apply(self, img, copy=True, **params):
dark_image_grey_fourier = np.fft.fftshift(np.fft.fft2(rgb2gray(img)))
mask = np.log(abs(dark_image_grey_fourier)).astype(np.uint8)
mask = cv2.resize(mask, (img.shape[1], img.shape[0]))
if self.mode == 0:
return np.asarray(cv2.bitwise_and(img, img, mask=mask))
else:
mask = np.asarray(mask)
image = cv2.merge((mask, mask, mask))
return image
class SR(DualTransform):
def __init__(self, model_sr, p=0.5) -> None:
super(SR, self).__init__(p=p)
self.prob = p
self.model_sr = model_sr
def apply(self, img, copy=True, **params):
img = cv2.resize(img, (int(img.shape[1]/2), int(img.shape[0]/2)), interpolation = cv2.INTER_AREA)
img = np.transpose(img, (2, 0, 1))
img = torch.tensor(img, dtype=torch.float).unsqueeze(0).to(2)
sr_img = self.model_sr(img)
return sr_img.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
class DCT(DualTransform):
def __init__(self, mode, p=0.5) -> None:
super(DCT, self).__init__(p=p)
self.prob = p
self.mode = mode
def rgb2gray(self, rgb):
return cv2.cvtColor(rgb, cv2.COLOR_BGR2GRAY)
def apply(self, img, copy=True, **params):
gray_img = self.rgb2gray(img)
dct_coefficients = cv2.dct(cv2.dct(np.float32(gray_img), flags=cv2.DCT_ROWS), flags=cv2.DCT_ROWS)
epsilon = 1
mask = np.log(np.abs(dct_coefficients) + epsilon).astype(np.uint8)
mask = cv2.resize(mask, (img.shape[1], img.shape[0]))
if self.mode == 0:
return cv2.bitwise_and(img, img, mask=mask)
else:
dct_coefficients = np.asarray(dct_coefficients)
image = cv2.merge((dct_coefficients, dct_coefficients, dct_coefficients))
return image
import albumentations as A
def get_transforms(* , data) :
size = CFG.size
if data == 'train':
return Compose([
ImageCompression(quality_lower=40, quality_upper=100, p=0.1),
HorizontalFlip(),
GaussNoise(p=0.3),
ISONoise(p=0.3),
MultiplicativeNoise(p=0.3),
OneOf([
IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_LINEAR),
IsotropicResize(max_side=size, interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR),
CustomRandomCrop(size=size)
], p=1),
Resize(height=size, width=size),
PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT , value=0 , p=1),
OneOf([RandomBrightnessContrast(), FancyPCA(), HueSaturationValue()], p=0.5),
OneOf([CoarseDropout()], p=0.05),
ToGray(p=0.1),
ToSepia(p=0.05),
RandomShadow(p=0.05),
RandomGamma(p=0.1),
ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=10, border_mode=cv2.BORDER_CONSTANT, p=0.5),
FFT(mode=0, p=0.05),
DCT(mode=1, p=0.5) ,
Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
ToTensorV2(),
])
elif data == 'valid':
return Compose([
IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
Resize(CFG.size, CFG.size),
# PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT , value=0 ),
Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
ToTensorV2(),
])
train_dataset = TrainDataset(train , transform= get_transforms(data = "train"))
fig, axes = plt.subplots(2, 4, figsize=(10, 7))
for i in range(2):
for j in range(4):
index = i * 3 + j
if index < len(train_dataset):
image, label = train_dataset[index]
axes[i, j].imshow(image.permute(1,2,0))
if label.numpy() == 1:
axes[i, j].set_title("Fake", color="r")
else:
axes[i, j].set_title("Real", color="g")
axes[i, j].axis('off')
plt.tight_layout()
plt.show()
folds = train.copy()
Fold = KFold(n_splits = CFG.n_fold , shuffle = True , random_state = CFG.seed)
for n, (train_index, val_index) in enumerate(Fold.split(folds, folds[CFG.target_col])):
folds.loc[val_index, 'fold'] = int(n)
folds['fold'] = folds['fold'].astype(int)
model = Two_Stream_Net()
model(train_dataset[0][0].unsqueeze(1).permute(1,0,2,3))
Using dropout 0.5 Using dropout 0.5
tensor([[-1.3592, 0.0468]], grad_fn=<AddmmBackward0>)
import wandb
try:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
api_key = user_secrets.get_secret("wandb_api")
wandb.login(key=api_key)
anonymous = None
except:
anonymous = "must"
print('To use your W&B account,\nGo to Add-ons -> Secrets and provide your W&B access token. Use the Label name as WANDB. \nGet your W&B access token from here: https://wandb.ai/authorize')
wandb: W&B API key is configured. Use `wandb login --relogin` to force relogin wandb: WARNING If you're specifying your api key in code, ensure this code is not shared publicly. wandb: WARNING Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line. wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
run = wandb.init(entity = 'lassouedaymenla',
project = 'tutorial',
save_code = True,
name = "FaceForgery"
)
wandb: Currently logged in as: lassouedaymenla. Use `wandb login --relogin` to force relogin
/kaggle/working/wandb/run-20240722_172435-gi80kvwo
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def asMinutes(s):
m = math.floor(s / 60)
s -= m * 60
return '%dm %ds' % (m, s)
def timeSince(since, percent):
now = time.time()
s = now - since
es = s / (percent)
rs = es - s
return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))
def train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device):
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
scores = AverageMeter()
# switch to train mode
model.train()
start = end = time.time()
global_step = 0
for step, (images, labels) in enumerate(train_loader):
# measure data loading time
data_time.update(time.time() - end)
images = images.to(device)
labels = labels.to(device)
batch_size = labels.size(0)
y_preds = model(images)
labels = labels.cuda()
y_preds = y_preds.cuda()
# debug
# print(torch.nn.functional.softmax(y_preds, dim=1))
# print(labels)
loss = criterion(y_preds, labels)
# record loss
losses.update(loss.item(), batch_size)
# # Logging to wandb
# wandb.log({"Training Loss": loss.item(), "Epoch": epoch, "Step": global_step})
if CFG.gradient_accumulation_steps > 1:
loss = loss / CFG.gradient_accumulation_steps
if CFG.apex:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
if (step + 1) % CFG.gradient_accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
global_step += 1
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if step % CFG.print_freq == 0 or step == (len(train_loader)-1):
print('Epoch: [{0}][{1}/{2}] '
'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
'Elapsed {remain:s} '
'Loss: {loss.val:.4f}({loss.avg:.4f}) '
'Grad: {grad_norm:.4f} '
#'LR: {lr:.6f} '
.format(
epoch+1, step, len(train_loader), batch_time=batch_time,
data_time=data_time, loss=losses,
remain=timeSince(start, float(step+1)/len(train_loader)),
grad_norm=grad_norm,
#lr=scheduler.get_lr()[0],
))
# # Log epoch summary to wandb
# wandb.log({"Epoch Training Loss": losses.avg, "Epoch": epoch})
wandb.log({
"Train Loss": losses.val,
"Step": step,
"Gradient Norm": grad_norm,
"Learning Rate": optimizer.param_groups[0]['lr'] # Add this line to log the learning rate
})
return losses.avg
def valid_fn(valid_loader, model, criterion, device):
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
scores = AverageMeter()
# switch to evaluation mode
model.eval()
preds = []
start = end = time.time()
for step, (images, labels) in enumerate(valid_loader):
# measure data loading time
data_time.update(time.time() - end)
images = images.to(device)
labels = labels.to(device)
batch_size = labels.size(0)
# compute loss
with torch.no_grad():
y_preds = model(images)
labels = labels.cuda()
y_preds = y_preds.cuda()
loss = criterion(y_preds, labels)
losses.update(loss.item(), batch_size)
y_preds = torch.nn.functional.softmax(y_preds, dim=1)
# record accuracy
y_preds = y_preds.to('cpu').numpy()
preds.append(y_preds)
if CFG.gradient_accumulation_steps > 1:
loss = loss / CFG.gradient_accumulation_steps
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if step % CFG.print_freq == 0 or step == (len(valid_loader)-1):
print('EVAL: [{0}/{1}] '
'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
'Elapsed {remain:s} '
'Loss: {loss.val:.4f}({loss.avg:.4f}) '
.format(
step, len(valid_loader), batch_time=batch_time,
data_time=data_time, loss=losses,
remain=timeSince(start, float(step+1)/len(valid_loader)),
))
wandb.log({
"Val Loss ": losses.val,
"Val Step": step ,
})
predictions = np.concatenate(preds)
return losses.avg, predictions
def inference(model, states, test_loader, device):
model.to(device)
tk0 = tqdm(enumerate(test_loader), total=len(test_loader))
probs = []
for i, (images) in tk0:
images = images.to(device)
avg_preds = []
for state in states:
model.load_state_dict(state['model'])
model.eval()
with torch.no_grad():
# print(images.shape)
y_preds = model(images)
avg_preds.append(y_preds.to('cpu').numpy())
avg_preds = np.mean(avg_preds, axis=0)
probs.append(avg_preds)
probs = np.concatenate(probs)
return probs
def train_loop(folds, fold):
LOGGER.info(f"========== fold: {fold} training ==========")
# ====================================================
# loader
# ====================================================
trn_idx = folds[folds['fold'] != fold].index
val_idx = folds[folds['fold'] == fold].index
train_folds = folds.loc[trn_idx].reset_index(drop=True)
valid_folds = folds.loc[val_idx].reset_index(drop=True)
train_dataset = TrainDataset(train_folds,
transform=get_transforms(data='train'))
valid_dataset = TrainDataset(valid_folds,
transform=get_transforms(data='valid'))
train_loader = DataLoader(train_dataset,
batch_size=CFG.batch_size,
shuffle=False,
num_workers=CFG.num_workers, sampler=BalanceClassSampler(labels=train_dataset.get_labels(), mode="upsampling") ,
pin_memory=True, drop_last=True)
valid_loader = DataLoader(valid_dataset,
batch_size=CFG.batch_size,
shuffle=False,
num_workers=CFG.num_workers, pin_memory=True, drop_last=False)
# ====================================================
# scheduler
# ====================================================
def get_scheduler(optimizer):
if CFG.scheduler=='ReduceLROnPlateau':
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=CFG.factor, patience=CFG.patience, verbose=True, eps=CFG.eps)
elif CFG.scheduler=='CosineAnnealingLR':
scheduler = CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)
elif CFG.scheduler=='CosineAnnealingWarmRestarts':
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1, eta_min=CFG.min_lr, last_epoch=-1)
return scheduler
# ====================================================
# model & optimizer
# ====================================================
model = Two_Stream_Net()
model.to(device)
optimizer = Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay, amsgrad=False)
scheduler = get_scheduler(optimizer)
# ====================================================
# apex
# ====================================================
if CFG.apex:
model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
# ====================================================
# loop
# ====================================================
criterion = nn.CrossEntropyLoss().cuda()
best_score = 50000
best_loss = np.inf
for epoch in range(CFG.epochs):
start_time = time.time()
# train
avg_loss = train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device)
# eval
avg_val_loss, preds = valid_fn(valid_loader, model, criterion, device)
valid_labels = valid_folds[CFG.target_col].values
if isinstance(scheduler, ReduceLROnPlateau):
scheduler.step(avg_val_loss)
elif isinstance(scheduler, CosineAnnealingLR):
scheduler.step()
elif isinstance(scheduler, CosineAnnealingWarmRestarts):
scheduler.step()
# scoring
score = get_score(valid_labels, preds)
print(score)
preds= torch.nn.functional.softmax(torch.from_numpy(preds), dim=1).numpy()[:,1]
score2 = roc_auc_score(valid_labels, preds)
elapsed = time.time() - start_time
LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f} avg_val_loss: {avg_val_loss:.4f} time: {elapsed:.0f}s') #.info makes the msg shows in red cadre
LOGGER.info(f'Epoch {epoch+1} - LogLoss: {score} - AUC: {score2}')
if score < best_score:
best_score = score
LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
torch.save({'model': model.state_dict(),
'preds': preds},
OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best.pth')
check_point = torch.load(OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best.pth')
#valid_folds[[str(c) for c in range(5)]] = check_point['preds']
#valid_folds['preds'] = check_point['preds'].argmax(1)
return
def main():
"""
Prepare: 1.train 2.test 3.submission 4.folds
"""
def get_result(result_df):
preds = result_df['preds'].values
labels = result_df[CFG.target_col].values
score = get_score(labels, preds)
LOGGER.info(f'Score: {score:<.5f}')
if CFG.train:
# train
oof_df = pd.DataFrame()
for fold in range(CFG.n_fold):
if fold in CFG.trn_fold:
train_loop(folds, fold)
#oof_df = pd.concat([oof_df, _oof_df])
#LOGGER.info(f"========== fold: {fold} result ==========")
#get_result(_oof_df)
# CV result
LOGGER.info(f"========== CV ==========")
#get_result(oof_df)
# save result
#oof_df.to_csv(OUTPUT_DIR+'oof_df.csv', index=False)
if CFG.inference:
# inference
model = CustomResNext(CFG.model_name, pretrained=False)
states = [torch.load(OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best.pth') for fold in CFG.trn_fold]
test_dataset = TestDataset(test, transform=get_transforms(data='valid'))
test_loader = DataLoader(test_dataset, batch_size=CFG.batch_size, shuffle=False,
num_workers=CFG.num_workers, pin_memory=True)
predictions = inference(model, states, test_loader, device)
# submission
print(predictions)
test['label'] = torch.nn.functional.softmax(torch.from_numpy(predictions), dim=1).numpy()[:,1]
print(test['label'])
test[['img_name', 'label']].to_csv(OUTPUT_DIR+'submission.csv', index=False)
if __name__ == '__main__':
main()
========== fold: 0 training ==========
Using dropout 0.5 Using dropout 0.5 Epoch: [1][0/40998] Data 1.775 (1.775) Elapsed 0m 4s (remain 3105m 11s) Loss: 0.8732(0.8732) Grad: 13.5833 Epoch: [1][20/40998] Data 0.552 (0.585) Elapsed 0m 21s (remain 684m 25s) Loss: 0.7016(0.8085) Grad: 8.5844 Epoch: [1][40/40998] Data 0.552 (0.569) Elapsed 0m 38s (remain 634m 19s) Loss: 0.8539(0.7425) Grad: 10.9534 Epoch: [1][60/40998] Data 0.553 (0.564) Elapsed 0m 55s (remain 616m 52s) Loss: 0.6197(0.7149) Grad: 7.0083 Epoch: [1][80/40998] Data 0.553 (0.561) Elapsed 1m 12s (remain 607m 52s) Loss: 0.5332(0.6929) Grad: 7.3943 Epoch: [1][100/40998] Data 0.553 (0.559) Elapsed 1m 29s (remain 602m 20s) Loss: 0.5192(0.6752) Grad: 6.0233 Epoch: [1][120/40998] Data 0.552 (0.558) Elapsed 1m 46s (remain 598m 32s) Loss: 0.6919(0.6608) Grad: 11.2926 Epoch: [1][140/40998] Data 0.553 (0.558) Elapsed 2m 3s (remain 595m 44s) Loss: 0.5706(0.6463) Grad: 6.2257 Epoch: [1][160/40998] Data 0.553 (0.557) Elapsed 2m 20s (remain 593m 34s) Loss: 0.7004(0.6316) Grad: 7.9941 Epoch: [1][180/40998] Data 0.552 (0.556) Elapsed 2m 37s (remain 591m 49s) Loss: 0.5212(0.6236) Grad: 4.7840 Epoch: [1][200/40998] Data 0.553 (0.556) Elapsed 2m 54s (remain 590m 21s) Loss: 0.6489(0.6154) Grad: 6.1323 Epoch: [1][220/40998] Data 0.552 (0.556) Elapsed 3m 11s (remain 589m 6s) Loss: 0.5790(0.6054) Grad: 7.4238 Epoch: [1][240/40998] Data 0.551 (0.556) Elapsed 3m 28s (remain 588m 0s) Loss: 0.6314(0.5980) Grad: 7.0559 Epoch: [1][260/40998] Data 0.553 (0.555) Elapsed 3m 45s (remain 587m 2s) Loss: 0.4474(0.5893) Grad: 4.5827 Epoch: [1][280/40998] Data 0.552 (0.555) Elapsed 4m 2s (remain 586m 10s) Loss: 0.2868(0.5788) Grad: 3.7827 Epoch: [1][300/40998] Data 0.553 (0.555) Elapsed 4m 19s (remain 585m 22s) Loss: 0.3590(0.5748) Grad: 3.8241 Epoch: [1][320/40998] Data 0.553 (0.555) Elapsed 4m 36s (remain 584m 39s) Loss: 0.4217(0.5700) Grad: 4.3578 Epoch: [1][340/40998] Data 0.553 (0.555) Elapsed 4m 53s (remain 583m 58s) Loss: 0.5672(0.5651) Grad: 5.0714 Epoch: [1][360/40998] Data 0.553 (0.555) Elapsed 5m 10s (remain 583m 20s) Loss: 0.5896(0.5608) Grad: 6.0533 Epoch: [1][380/40998] Data 0.552 (0.554) Elapsed 5m 27s (remain 582m 44s) Loss: 0.6467(0.5573) Grad: 6.7685 Epoch: [1][400/40998] Data 0.553 (0.554) Elapsed 5m 45s (remain 582m 10s) Loss: 0.2545(0.5520) Grad: 4.3219 Epoch: [1][420/40998] Data 0.553 (0.554) Elapsed 6m 2s (remain 581m 38s) Loss: 0.4342(0.5507) Grad: 3.2260 Epoch: [1][440/40998] Data 0.552 (0.554) Elapsed 6m 19s (remain 581m 8s) Loss: 0.5479(0.5483) Grad: 5.4349 Epoch: [1][460/40998] Data 0.553 (0.554) Elapsed 6m 36s (remain 580m 37s) Loss: 0.4718(0.5442) Grad: 4.3204 Epoch: [1][480/40998] Data 0.553 (0.554) Elapsed 6m 53s (remain 580m 9s) Loss: 0.4894(0.5410) Grad: 4.7152 Epoch: [1][500/40998] Data 0.552 (0.554) Elapsed 7m 10s (remain 579m 41s) Loss: 0.4900(0.5373) Grad: 4.7308 Epoch: [1][520/40998] Data 0.552 (0.554) Elapsed 7m 27s (remain 579m 15s) Loss: 0.2977(0.5369) Grad: 2.6538 Epoch: [1][540/40998] Data 0.553 (0.554) Elapsed 7m 44s (remain 578m 48s) Loss: 0.3365(0.5333) Grad: 3.6212 Epoch: [1][560/40998] Data 0.553 (0.554) Elapsed 8m 1s (remain 578m 23s) Loss: 0.7219(0.5302) Grad: 8.6209 Epoch: [1][580/40998] Data 0.553 (0.554) Elapsed 8m 18s (remain 577m 58s) Loss: 0.4107(0.5271) Grad: 3.9289 Epoch: [1][600/40998] Data 0.553 (0.554) Elapsed 8m 35s (remain 577m 33s) Loss: 0.4708(0.5250) Grad: 4.1229 Epoch: [1][620/40998] Data 0.553 (0.554) Elapsed 8m 52s (remain 577m 9s) Loss: 0.4260(0.5233) Grad: 4.5030 Epoch: [1][640/40998] Data 0.552 (0.554) Elapsed 9m 9s (remain 576m 45s) Loss: 0.7291(0.5207) Grad: 6.3236 Epoch: [1][660/40998] Data 0.553 (0.554) Elapsed 9m 26s (remain 576m 22s) Loss: 0.4016(0.5187) Grad: 3.1535 Epoch: [1][680/40998] Data 0.553 (0.554) Elapsed 9m 43s (remain 575m 59s) Loss: 0.3967(0.5166) Grad: 3.5701 Epoch: [1][700/40998] Data 0.553 (0.554) Elapsed 10m 0s (remain 575m 37s) Loss: 0.4938(0.5147) Grad: 5.0574 Epoch: [1][720/40998] Data 0.552 (0.554) Elapsed 10m 17s (remain 575m 15s) Loss: 0.4573(0.5133) Grad: 4.9107 Epoch: [1][740/40998] Data 0.553 (0.554) Elapsed 10m 34s (remain 574m 53s) Loss: 0.3223(0.5107) Grad: 3.1092 Epoch: [1][760/40998] Data 0.553 (0.554) Elapsed 10m 51s (remain 574m 32s) Loss: 0.2485(0.5070) Grad: 3.5340 Epoch: [1][780/40998] Data 0.553 (0.554) Elapsed 11m 9s (remain 574m 10s) Loss: 0.4243(0.5047) Grad: 5.0002 Epoch: [1][800/40998] Data 0.553 (0.554) Elapsed 11m 26s (remain 573m 49s) Loss: 0.5623(0.5030) Grad: 5.6730 Epoch: [1][820/40998] Data 0.553 (0.554) Elapsed 11m 43s (remain 573m 28s) Loss: 0.3768(0.5005) Grad: 3.8652 Epoch: [1][840/40998] Data 0.553 (0.554) Elapsed 12m 0s (remain 573m 8s) Loss: 0.3171(0.4984) Grad: 4.0671 Epoch: [1][860/40998] Data 0.553 (0.554) Elapsed 12m 17s (remain 572m 47s) Loss: 0.5697(0.4968) Grad: 5.0413 Epoch: [1][880/40998] Data 0.553 (0.554) Elapsed 12m 34s (remain 572m 27s) Loss: 0.2385(0.4948) Grad: 3.0250 Epoch: [1][900/40998] Data 0.553 (0.554) Elapsed 12m 51s (remain 572m 6s) Loss: 0.3619(0.4934) Grad: 4.8707 Epoch: [1][920/40998] Data 0.553 (0.554) Elapsed 13m 8s (remain 571m 46s) Loss: 0.4019(0.4920) Grad: 2.8238 Epoch: [1][940/40998] Data 0.552 (0.553) Elapsed 13m 25s (remain 571m 29s) Loss: 0.6209(0.4903) Grad: 5.1697 Epoch: [1][960/40998] Data 0.552 (0.553) Elapsed 13m 42s (remain 571m 9s) Loss: 0.4208(0.4884) Grad: 3.7733 Epoch: [1][980/40998] Data 0.553 (0.553) Elapsed 13m 59s (remain 570m 49s) Loss: 0.4393(0.4873) Grad: 3.6124 Epoch: [1][1000/40998] Data 0.552 (0.553) Elapsed 14m 16s (remain 570m 30s) Loss: 0.4038(0.4851) Grad: 4.7643 Epoch: [1][1020/40998] Data 0.553 (0.553) Elapsed 14m 33s (remain 570m 11s) Loss: 0.3471(0.4838) Grad: 3.6725 Epoch: [1][1040/40998] Data 0.552 (0.553) Elapsed 14m 50s (remain 569m 52s) Loss: 0.4080(0.4822) Grad: 4.1811 Epoch: [1][1060/40998] Data 0.552 (0.553) Elapsed 15m 7s (remain 569m 32s) Loss: 0.4326(0.4809) Grad: 4.6237 Epoch: [1][1080/40998] Data 0.553 (0.553) Elapsed 15m 24s (remain 569m 13s) Loss: 0.3635(0.4792) Grad: 3.5998 Epoch: [1][1100/40998] Data 0.553 (0.553) Elapsed 15m 41s (remain 568m 54s) Loss: 0.2735(0.4782) Grad: 2.8812 Epoch: [1][1120/40998] Data 0.553 (0.553) Elapsed 15m 59s (remain 568m 36s) Loss: 0.2919(0.4769) Grad: 4.5212 Epoch: [1][1140/40998] Data 0.552 (0.553) Elapsed 16m 16s (remain 568m 17s) Loss: 0.4960(0.4759) Grad: 4.6030 Epoch: [1][1160/40998] Data 0.553 (0.553) Elapsed 16m 33s (remain 567m 58s) Loss: 0.5470(0.4740) Grad: 7.2479 Epoch: [1][1180/40998] Data 0.553 (0.553) Elapsed 16m 50s (remain 567m 39s) Loss: 0.5810(0.4730) Grad: 5.1567 Epoch: [1][1200/40998] Data 0.552 (0.553) Elapsed 17m 7s (remain 567m 21s) Loss: 0.3351(0.4714) Grad: 4.3528 Epoch: [1][1220/40998] Data 0.552 (0.553) Elapsed 17m 24s (remain 567m 2s) Loss: 0.3622(0.4700) Grad: 3.8108 Epoch: [1][1240/40998] Data 0.552 (0.553) Elapsed 17m 41s (remain 566m 44s) Loss: 0.4086(0.4690) Grad: 4.0388 Epoch: [1][1260/40998] Data 0.553 (0.553) Elapsed 17m 58s (remain 566m 25s) Loss: 0.2695(0.4673) Grad: 2.8184 Epoch: [1][1280/40998] Data 0.553 (0.553) Elapsed 18m 15s (remain 566m 7s) Loss: 0.4850(0.4659) Grad: 4.4943 Epoch: [1][1300/40998] Data 0.553 (0.553) Elapsed 18m 32s (remain 565m 48s) Loss: 0.5135(0.4646) Grad: 6.2567 Epoch: [1][1320/40998] Data 0.553 (0.553) Elapsed 18m 49s (remain 565m 30s) Loss: 0.2715(0.4629) Grad: 2.6880 Epoch: [1][1340/40998] Data 0.553 (0.553) Elapsed 19m 6s (remain 565m 12s) Loss: 0.3877(0.4619) Grad: 3.2289 Epoch: [1][1360/40998] Data 0.553 (0.553) Elapsed 19m 23s (remain 564m 53s) Loss: 0.3231(0.4608) Grad: 3.7210 Epoch: [1][1380/40998] Data 0.553 (0.553) Elapsed 19m 40s (remain 564m 35s) Loss: 0.4805(0.4588) Grad: 5.1187 Epoch: [1][1400/40998] Data 0.553 (0.553) Elapsed 19m 57s (remain 564m 17s) Loss: 0.1662(0.4573) Grad: 3.3010 Epoch: [1][1420/40998] Data 0.553 (0.553) Elapsed 20m 14s (remain 563m 59s) Loss: 0.3386(0.4555) Grad: 4.0177 Epoch: [1][1440/40998] Data 0.553 (0.553) Elapsed 20m 32s (remain 563m 41s) Loss: 0.6130(0.4542) Grad: 6.4697 Epoch: [1][1460/40998] Data 0.553 (0.553) Elapsed 20m 49s (remain 563m 23s) Loss: 0.2941(0.4521) Grad: 3.5293 Epoch: [1][1480/40998] Data 0.553 (0.553) Elapsed 21m 6s (remain 563m 4s) Loss: 0.4725(0.4512) Grad: 6.7537 Epoch: [1][1500/40998] Data 0.552 (0.553) Elapsed 21m 23s (remain 562m 46s) Loss: 0.5451(0.4503) Grad: 5.2309 Epoch: [1][1520/40998] Data 0.553 (0.553) Elapsed 21m 40s (remain 562m 28s) Loss: 0.2884(0.4493) Grad: 3.9010 Epoch: [1][1540/40998] Data 0.552 (0.553) Elapsed 21m 57s (remain 562m 10s) Loss: 0.3341(0.4481) Grad: 2.8484 Epoch: [1][1560/40998] Data 0.552 (0.553) Elapsed 22m 14s (remain 561m 52s) Loss: 0.2223(0.4467) Grad: 2.4516 Epoch: [1][1580/40998] Data 0.553 (0.553) Elapsed 22m 31s (remain 561m 34s) Loss: 0.2608(0.4451) Grad: 4.6336 Epoch: [1][1600/40998] Data 0.552 (0.553) Elapsed 22m 48s (remain 561m 17s) Loss: 0.3960(0.4441) Grad: 3.7286 Epoch: [1][1620/40998] Data 0.552 (0.553) Elapsed 23m 5s (remain 560m 59s) Loss: 0.5299(0.4433) Grad: 6.1462 Epoch: [1][1640/40998] Data 0.553 (0.553) Elapsed 23m 22s (remain 560m 41s) Loss: 0.1687(0.4425) Grad: 1.7323 Epoch: [1][1660/40998] Data 0.553 (0.553) Elapsed 23m 39s (remain 560m 23s) Loss: 0.2183(0.4418) Grad: 2.6848 Epoch: [1][1680/40998] Data 0.553 (0.553) Elapsed 23m 56s (remain 560m 5s) Loss: 0.4905(0.4404) Grad: 4.7968 Epoch: [1][1700/40998] Data 0.553 (0.553) Elapsed 24m 13s (remain 559m 47s) Loss: 0.4574(0.4398) Grad: 4.9223 Epoch: [1][1720/40998] Data 0.553 (0.553) Elapsed 24m 30s (remain 559m 29s) Loss: 0.2995(0.4394) Grad: 3.5081 Epoch: [1][1740/40998] Data 0.553 (0.553) Elapsed 24m 47s (remain 559m 11s) Loss: 0.4469(0.4380) Grad: 3.8852 Epoch: [1][1760/40998] Data 0.553 (0.553) Elapsed 25m 5s (remain 558m 54s) Loss: 0.4500(0.4374) Grad: 3.9600 Epoch: [1][1780/40998] Data 0.552 (0.553) Elapsed 25m 22s (remain 558m 36s) Loss: 0.8086(0.4368) Grad: 5.9817 Epoch: [1][1800/40998] Data 0.552 (0.553) Elapsed 25m 39s (remain 558m 18s) Loss: 0.2335(0.4357) Grad: 2.5538 Epoch: [1][1820/40998] Data 0.553 (0.553) Elapsed 25m 56s (remain 558m 0s) Loss: 0.2386(0.4345) Grad: 3.7621 Epoch: [1][1840/40998] Data 0.553 (0.553) Elapsed 26m 13s (remain 557m 43s) Loss: 0.1864(0.4334) Grad: 2.3264 Epoch: [1][1860/40998] Data 0.551 (0.553) Elapsed 26m 30s (remain 557m 25s) Loss: 0.1028(0.4325) Grad: 1.5620 Epoch: [1][1880/40998] Data 0.553 (0.553) Elapsed 26m 47s (remain 557m 7s) Loss: 0.1165(0.4315) Grad: 1.5070 Epoch: [1][1900/40998] Data 0.553 (0.553) Elapsed 27m 4s (remain 556m 50s) Loss: 0.6323(0.4310) Grad: 4.1203 Epoch: [1][1920/40998] Data 0.553 (0.553) Elapsed 27m 21s (remain 556m 32s) Loss: 0.4528(0.4302) Grad: 4.6585 Epoch: [1][1940/40998] Data 0.552 (0.553) Elapsed 27m 38s (remain 556m 14s) Loss: 0.2887(0.4288) Grad: 3.0704 Epoch: [1][1960/40998] Data 0.552 (0.553) Elapsed 27m 55s (remain 555m 56s) Loss: 0.2247(0.4284) Grad: 2.3573 Epoch: [1][1980/40998] Data 0.553 (0.553) Elapsed 28m 12s (remain 555m 39s) Loss: 0.4611(0.4277) Grad: 3.7604 Epoch: [1][2000/40998] Data 0.553 (0.553) Elapsed 28m 29s (remain 555m 21s) Loss: 0.3091(0.4266) Grad: 3.2972 Epoch: [1][2020/40998] Data 0.552 (0.553) Elapsed 28m 46s (remain 555m 3s) Loss: 0.4423(0.4258) Grad: 3.3024 Epoch: [1][2040/40998] Data 0.553 (0.553) Elapsed 29m 3s (remain 554m 46s) Loss: 0.3166(0.4249) Grad: 4.1661 Epoch: [1][2060/40998] Data 0.553 (0.553) Elapsed 29m 20s (remain 554m 28s) Loss: 0.2119(0.4242) Grad: 3.0091 Epoch: [1][2080/40998] Data 0.553 (0.553) Elapsed 29m 38s (remain 554m 11s) Loss: 0.1845(0.4236) Grad: 2.2240 Epoch: [1][2100/40998] Data 0.553 (0.553) Elapsed 29m 55s (remain 553m 53s) Loss: 0.1722(0.4229) Grad: 1.4631 Epoch: [1][2120/40998] Data 0.553 (0.553) Elapsed 30m 12s (remain 553m 35s) Loss: 0.4426(0.4222) Grad: 3.5704 Epoch: [1][2140/40998] Data 0.552 (0.553) Elapsed 30m 29s (remain 553m 18s) Loss: 0.4389(0.4214) Grad: 4.6936 Epoch: [1][2160/40998] Data 0.553 (0.553) Elapsed 30m 46s (remain 553m 0s) Loss: 0.3045(0.4207) Grad: 3.4256 Epoch: [1][2180/40998] Data 0.553 (0.553) Elapsed 31m 3s (remain 552m 43s) Loss: 0.2534(0.4202) Grad: 3.2902 Epoch: [1][2200/40998] Data 0.553 (0.553) Elapsed 31m 20s (remain 552m 25s) Loss: 0.1792(0.4189) Grad: 2.0067 Epoch: [1][2220/40998] Data 0.555 (0.553) Elapsed 31m 37s (remain 552m 8s) Loss: 0.2259(0.4180) Grad: 3.2194 Epoch: [1][2240/40998] Data 0.553 (0.553) Elapsed 31m 54s (remain 551m 50s) Loss: 0.2489(0.4174) Grad: 3.8718 Epoch: [1][2260/40998] Data 0.553 (0.553) Elapsed 32m 11s (remain 551m 33s) Loss: 0.2170(0.4167) Grad: 2.2986 Epoch: [1][2280/40998] Data 0.553 (0.553) Elapsed 32m 28s (remain 551m 15s) Loss: 0.4465(0.4161) Grad: 3.4946 Epoch: [1][2300/40998] Data 0.553 (0.553) Elapsed 32m 45s (remain 550m 58s) Loss: 0.3473(0.4154) Grad: 3.8170 Epoch: [1][2320/40998] Data 0.553 (0.553) Elapsed 33m 2s (remain 550m 40s) Loss: 0.4576(0.4147) Grad: 3.2494 Epoch: [1][2340/40998] Data 0.551 (0.553) Elapsed 33m 19s (remain 550m 23s) Loss: 0.3566(0.4137) Grad: 2.9380 Epoch: [1][2360/40998] Data 0.552 (0.553) Elapsed 33m 36s (remain 550m 5s) Loss: 0.1940(0.4128) Grad: 2.9704 Epoch: [1][2380/40998] Data 0.553 (0.553) Elapsed 33m 53s (remain 549m 48s) Loss: 0.1885(0.4123) Grad: 2.3402 Epoch: [1][2400/40998] Data 0.553 (0.553) Elapsed 34m 11s (remain 549m 30s) Loss: 0.3742(0.4118) Grad: 3.9260 Epoch: [1][2420/40998] Data 0.553 (0.553) Elapsed 34m 28s (remain 549m 13s) Loss: 0.4667(0.4115) Grad: 3.1082 Epoch: [1][2440/40998] Data 0.552 (0.553) Elapsed 34m 45s (remain 548m 55s) Loss: 0.3068(0.4108) Grad: 3.4124 Epoch: [1][2460/40998] Data 0.553 (0.553) Elapsed 35m 2s (remain 548m 38s) Loss: 0.1703(0.4099) Grad: 2.3023 Epoch: [1][2480/40998] Data 0.553 (0.553) Elapsed 35m 19s (remain 548m 20s) Loss: 0.3062(0.4092) Grad: 3.6527 Epoch: [1][2500/40998] Data 0.553 (0.553) Elapsed 35m 36s (remain 548m 3s) Loss: 0.5162(0.4085) Grad: 5.0258 Epoch: [1][2520/40998] Data 0.553 (0.553) Elapsed 35m 53s (remain 547m 46s) Loss: 0.3871(0.4076) Grad: 3.3546 Epoch: [1][2540/40998] Data 0.552 (0.553) Elapsed 36m 10s (remain 547m 28s) Loss: 0.1952(0.4071) Grad: 2.2715 Epoch: [1][2560/40998] Data 0.552 (0.553) Elapsed 36m 27s (remain 547m 11s) Loss: 0.2784(0.4067) Grad: 1.9070 Epoch: [1][2580/40998] Data 0.553 (0.553) Elapsed 36m 44s (remain 546m 53s) Loss: 0.2288(0.4060) Grad: 2.5700 Epoch: [1][2600/40998] Data 0.553 (0.553) Elapsed 37m 1s (remain 546m 36s) Loss: 0.8520(0.4052) Grad: 8.7013 Epoch: [1][2620/40998] Data 0.552 (0.553) Elapsed 37m 18s (remain 546m 18s) Loss: 0.2892(0.4047) Grad: 2.6977 Epoch: [1][2640/40998] Data 0.552 (0.553) Elapsed 37m 35s (remain 546m 1s) Loss: 0.2398(0.4041) Grad: 1.7540 Epoch: [1][2660/40998] Data 0.553 (0.553) Elapsed 37m 52s (remain 545m 44s) Loss: 0.2134(0.4033) Grad: 1.8970 Epoch: [1][2680/40998] Data 0.553 (0.553) Elapsed 38m 9s (remain 545m 26s) Loss: 0.1671(0.4026) Grad: 2.1689 Epoch: [1][2700/40998] Data 0.553 (0.553) Elapsed 38m 26s (remain 545m 9s) Loss: 0.1771(0.4021) Grad: 1.6560 Epoch: [1][2720/40998] Data 0.553 (0.553) Elapsed 38m 43s (remain 544m 51s) Loss: 0.2230(0.4016) Grad: 2.3146 Epoch: [1][2740/40998] Data 0.553 (0.553) Elapsed 39m 1s (remain 544m 34s) Loss: 0.2867(0.4009) Grad: 2.9328 Epoch: [1][2760/40998] Data 0.553 (0.553) Elapsed 39m 18s (remain 544m 17s) Loss: 0.2763(0.4005) Grad: 3.3305 Epoch: [1][2780/40998] Data 0.553 (0.553) Elapsed 39m 35s (remain 543m 59s) Loss: 0.2568(0.4001) Grad: 2.3604 Epoch: [1][2800/40998] Data 0.553 (0.553) Elapsed 39m 52s (remain 543m 42s) Loss: 0.3026(0.3994) Grad: 4.7383 Epoch: [1][2820/40998] Data 0.552 (0.553) Elapsed 40m 9s (remain 543m 25s) Loss: 0.2071(0.3989) Grad: 2.3130 Epoch: [1][2840/40998] Data 0.553 (0.553) Elapsed 40m 26s (remain 543m 7s) Loss: 0.3456(0.3981) Grad: 3.2836 Epoch: [1][2860/40998] Data 0.553 (0.553) Elapsed 40m 43s (remain 542m 50s) Loss: 0.2130(0.3973) Grad: 2.6025 Epoch: [1][2880/40998] Data 0.552 (0.553) Elapsed 41m 0s (remain 542m 33s) Loss: 0.4571(0.3966) Grad: 5.9418 Epoch: [1][2900/40998] Data 0.553 (0.553) Elapsed 41m 17s (remain 542m 15s) Loss: 0.2048(0.3962) Grad: 3.8048 Epoch: [1][2920/40998] Data 0.553 (0.553) Elapsed 41m 34s (remain 541m 58s) Loss: 0.2103(0.3954) Grad: 3.1009 Epoch: [1][2940/40998] Data 0.553 (0.553) Elapsed 41m 51s (remain 541m 41s) Loss: 0.4970(0.3944) Grad: 4.4819 Epoch: [1][2960/40998] Data 0.553 (0.553) Elapsed 42m 8s (remain 541m 23s) Loss: 0.4697(0.3940) Grad: 3.6993 Epoch: [1][2980/40998] Data 0.552 (0.553) Elapsed 42m 25s (remain 541m 6s) Loss: 0.4027(0.3937) Grad: 2.8779 Epoch: [1][3000/40998] Data 0.552 (0.553) Elapsed 42m 42s (remain 540m 49s) Loss: 0.3300(0.3933) Grad: 3.1115 Epoch: [1][3020/40998] Data 0.553 (0.553) Elapsed 42m 59s (remain 540m 31s) Loss: 0.4063(0.3928) Grad: 5.5049 Epoch: [1][3040/40998] Data 0.553 (0.553) Elapsed 43m 16s (remain 540m 14s) Loss: 0.2783(0.3925) Grad: 3.4340 Epoch: [1][3060/40998] Data 0.552 (0.553) Elapsed 43m 34s (remain 539m 57s) Loss: 0.1879(0.3920) Grad: 2.0003 Epoch: [1][3080/40998] Data 0.553 (0.553) Elapsed 43m 51s (remain 539m 39s) Loss: 0.2985(0.3916) Grad: 3.6128 Epoch: [1][3100/40998] Data 0.552 (0.553) Elapsed 44m 8s (remain 539m 22s) Loss: 0.2297(0.3912) Grad: 3.2206 Epoch: [1][3120/40998] Data 0.553 (0.553) Elapsed 44m 25s (remain 539m 5s) Loss: 0.6284(0.3911) Grad: 5.7342 Epoch: [1][3140/40998] Data 0.553 (0.553) Elapsed 44m 42s (remain 538m 47s) Loss: 0.4407(0.3906) Grad: 4.5784 Epoch: [1][3160/40998] Data 0.552 (0.553) Elapsed 44m 59s (remain 538m 30s) Loss: 0.1959(0.3900) Grad: 1.8897 Epoch: [1][3180/40998] Data 0.552 (0.553) Elapsed 45m 16s (remain 538m 13s) Loss: 0.3527(0.3895) Grad: 4.0058 Epoch: [1][3200/40998] Data 0.553 (0.553) Elapsed 45m 33s (remain 537m 55s) Loss: 0.4527(0.3890) Grad: 4.5512 Epoch: [1][3220/40998] Data 0.552 (0.553) Elapsed 45m 50s (remain 537m 38s) Loss: 0.4978(0.3886) Grad: 3.6290 Epoch: [1][3240/40998] Data 0.553 (0.553) Elapsed 46m 7s (remain 537m 21s) Loss: 0.2094(0.3880) Grad: 2.2683 Epoch: [1][3260/40998] Data 0.553 (0.553) Elapsed 46m 24s (remain 537m 4s) Loss: 0.3202(0.3876) Grad: 3.4282 Epoch: [1][3280/40998] Data 0.553 (0.553) Elapsed 46m 41s (remain 536m 46s) Loss: 0.2276(0.3870) Grad: 2.5730 Epoch: [1][3300/40998] Data 0.553 (0.553) Elapsed 46m 58s (remain 536m 29s) Loss: 0.2211(0.3863) Grad: 2.2181 Epoch: [1][3320/40998] Data 0.552 (0.553) Elapsed 47m 15s (remain 536m 12s) Loss: 0.3167(0.3857) Grad: 3.7030 Epoch: [1][3340/40998] Data 0.553 (0.553) Elapsed 47m 32s (remain 535m 54s) Loss: 0.2307(0.3852) Grad: 3.2164 Epoch: [1][3360/40998] Data 0.553 (0.553) Elapsed 47m 49s (remain 535m 37s) Loss: 0.2331(0.3849) Grad: 1.6425 Epoch: [1][3380/40998] Data 0.553 (0.553) Elapsed 48m 6s (remain 535m 20s) Loss: 0.3989(0.3844) Grad: 3.6748 Epoch: [1][3400/40998] Data 0.553 (0.553) Elapsed 48m 24s (remain 535m 3s) Loss: 0.3708(0.3839) Grad: 6.4415 Epoch: [1][3420/40998] Data 0.553 (0.553) Elapsed 48m 41s (remain 534m 45s) Loss: 0.3647(0.3835) Grad: 3.6918 Epoch: [1][3440/40998] Data 0.552 (0.553) Elapsed 48m 58s (remain 534m 28s) Loss: 0.3599(0.3833) Grad: 4.7906 Epoch: [1][3460/40998] Data 0.552 (0.553) Elapsed 49m 15s (remain 534m 11s) Loss: 0.4861(0.3831) Grad: 4.1276 Epoch: [1][3480/40998] Data 0.553 (0.553) Elapsed 49m 32s (remain 533m 54s) Loss: 0.2374(0.3827) Grad: 2.7598 Epoch: [1][3500/40998] Data 0.553 (0.553) Elapsed 49m 49s (remain 533m 36s) Loss: 0.1992(0.3821) Grad: 3.7130 Epoch: [1][3520/40998] Data 0.552 (0.553) Elapsed 50m 6s (remain 533m 19s) Loss: 0.3234(0.3818) Grad: 2.9725 Epoch: [1][3540/40998] Data 0.553 (0.553) Elapsed 50m 23s (remain 533m 2s) Loss: 0.2880(0.3813) Grad: 3.5453 Epoch: [1][3560/40998] Data 0.552 (0.553) Elapsed 50m 40s (remain 532m 45s) Loss: 0.2963(0.3808) Grad: 2.6271 Epoch: [1][3580/40998] Data 0.554 (0.553) Elapsed 50m 57s (remain 532m 27s) Loss: 0.1611(0.3803) Grad: 2.6611 Epoch: [1][3600/40998] Data 0.552 (0.553) Elapsed 51m 14s (remain 532m 10s) Loss: 0.3852(0.3798) Grad: 4.4395 Epoch: [1][3620/40998] Data 0.552 (0.553) Elapsed 51m 31s (remain 531m 53s) Loss: 0.3651(0.3790) Grad: 3.9010 Epoch: [1][3640/40998] Data 0.552 (0.553) Elapsed 51m 48s (remain 531m 36s) Loss: 0.2293(0.3784) Grad: 2.5971 Epoch: [1][3660/40998] Data 0.553 (0.553) Elapsed 52m 5s (remain 531m 19s) Loss: 0.2678(0.3779) Grad: 2.4506 Epoch: [1][3680/40998] Data 0.553 (0.553) Elapsed 52m 22s (remain 531m 1s) Loss: 0.1551(0.3773) Grad: 2.0285 Epoch: [1][3700/40998] Data 0.552 (0.553) Elapsed 52m 39s (remain 530m 44s) Loss: 0.2663(0.3767) Grad: 3.4638 Epoch: [1][3720/40998] Data 0.552 (0.553) Elapsed 52m 57s (remain 530m 27s) Loss: 0.2276(0.3761) Grad: 3.0073 Epoch: [1][3740/40998] Data 0.552 (0.553) Elapsed 53m 14s (remain 530m 10s) Loss: 0.2513(0.3755) Grad: 3.2773 Epoch: [1][3760/40998] Data 0.552 (0.553) Elapsed 53m 31s (remain 529m 52s) Loss: 0.2192(0.3753) Grad: 1.9776 Epoch: [1][3780/40998] Data 0.550 (0.553) Elapsed 53m 48s (remain 529m 35s) Loss: 0.2868(0.3750) Grad: 3.6985 Epoch: [1][3800/40998] Data 0.552 (0.553) Elapsed 54m 5s (remain 529m 18s) Loss: 0.1692(0.3746) Grad: 2.2160 Epoch: [1][3820/40998] Data 0.553 (0.553) Elapsed 54m 22s (remain 529m 1s) Loss: 0.1742(0.3741) Grad: 1.9484 Epoch: [1][3840/40998] Data 0.553 (0.553) Elapsed 54m 39s (remain 528m 44s) Loss: 0.3369(0.3737) Grad: 3.7118 Epoch: [1][3860/40998] Data 0.553 (0.553) Elapsed 54m 56s (remain 528m 26s) Loss: 0.3194(0.3733) Grad: 4.3723 Epoch: [1][3880/40998] Data 0.553 (0.553) Elapsed 55m 13s (remain 528m 9s) Loss: 0.2712(0.3727) Grad: 3.4874 Epoch: [1][3900/40998] Data 0.553 (0.553) Elapsed 55m 30s (remain 527m 52s) Loss: 0.2289(0.3724) Grad: 2.4119 Epoch: [1][3920/40998] Data 0.553 (0.553) Elapsed 55m 47s (remain 527m 35s) Loss: 0.2685(0.3719) Grad: 2.0318 Epoch: [1][3940/40998] Data 0.553 (0.553) Elapsed 56m 4s (remain 527m 18s) Loss: 0.1566(0.3714) Grad: 2.2899 Epoch: [1][3960/40998] Data 0.552 (0.553) Elapsed 56m 21s (remain 527m 0s) Loss: 0.2567(0.3708) Grad: 3.0348 Epoch: [1][3980/40998] Data 0.552 (0.553) Elapsed 56m 38s (remain 526m 43s) Loss: 0.5365(0.3707) Grad: 3.9728 Epoch: [1][4000/40998] Data 0.553 (0.553) Elapsed 56m 55s (remain 526m 26s) Loss: 0.2913(0.3704) Grad: 2.5863 Epoch: [1][4020/40998] Data 0.553 (0.553) Elapsed 57m 12s (remain 526m 9s) Loss: 0.3015(0.3701) Grad: 3.8422 Epoch: [1][4040/40998] Data 0.553 (0.553) Elapsed 57m 30s (remain 525m 52s) Loss: 0.2245(0.3698) Grad: 3.5343 Epoch: [1][4060/40998] Data 0.552 (0.553) Elapsed 57m 47s (remain 525m 34s) Loss: 0.4113(0.3693) Grad: 3.2621 Epoch: [1][4080/40998] Data 0.553 (0.553) Elapsed 58m 4s (remain 525m 17s) Loss: 0.1720(0.3691) Grad: 2.8943 Epoch: [1][4100/40998] Data 0.553 (0.553) Elapsed 58m 21s (remain 525m 0s) Loss: 0.3074(0.3688) Grad: 3.4855 Epoch: [1][4120/40998] Data 0.552 (0.553) Elapsed 58m 38s (remain 524m 43s) Loss: 0.4183(0.3683) Grad: 3.4040 Epoch: [1][4140/40998] Data 0.553 (0.553) Elapsed 58m 55s (remain 524m 26s) Loss: 0.6403(0.3680) Grad: 7.0657 Epoch: [1][4160/40998] Data 0.553 (0.553) Elapsed 59m 12s (remain 524m 8s) Loss: 0.1802(0.3676) Grad: 2.4993 Epoch: [1][4180/40998] Data 0.553 (0.553) Elapsed 59m 29s (remain 523m 51s) Loss: 0.5426(0.3673) Grad: 5.0276 Epoch: [1][4200/40998] Data 0.553 (0.553) Elapsed 59m 46s (remain 523m 34s) Loss: 0.2038(0.3667) Grad: 2.8629 Epoch: [1][4220/40998] Data 0.553 (0.553) Elapsed 60m 3s (remain 523m 17s) Loss: 0.2117(0.3663) Grad: 2.2533 Epoch: [1][4240/40998] Data 0.553 (0.553) Elapsed 60m 20s (remain 523m 0s) Loss: 0.3017(0.3660) Grad: 4.0852 Epoch: [1][4260/40998] Data 0.553 (0.553) Elapsed 60m 37s (remain 522m 42s) Loss: 0.1612(0.3655) Grad: 2.9645 Epoch: [1][4280/40998] Data 0.553 (0.553) Elapsed 60m 54s (remain 522m 25s) Loss: 0.8159(0.3652) Grad: 12.0382 Epoch: [1][4300/40998] Data 0.553 (0.553) Elapsed 61m 11s (remain 522m 8s) Loss: 0.2168(0.3646) Grad: 2.7309 Epoch: [1][4320/40998] Data 0.553 (0.553) Elapsed 61m 28s (remain 521m 51s) Loss: 0.1089(0.3644) Grad: 1.3551 Epoch: [1][4340/40998] Data 0.552 (0.553) Elapsed 61m 45s (remain 521m 34s) Loss: 0.2641(0.3640) Grad: 2.3534 Epoch: [1][4360/40998] Data 0.552 (0.553) Elapsed 62m 2s (remain 521m 17s) Loss: 0.0522(0.3634) Grad: 0.7182 Epoch: [1][4380/40998] Data 0.552 (0.553) Elapsed 62m 20s (remain 520m 59s) Loss: 0.4386(0.3631) Grad: 3.6599 Epoch: [1][4400/40998] Data 0.553 (0.553) Elapsed 62m 37s (remain 520m 42s) Loss: 0.3075(0.3630) Grad: 3.1810 Epoch: [1][4420/40998] Data 0.552 (0.553) Elapsed 62m 54s (remain 520m 25s) Loss: 0.1202(0.3624) Grad: 1.5744 Epoch: [1][4440/40998] Data 0.553 (0.553) Elapsed 63m 11s (remain 520m 8s) Loss: 0.1457(0.3618) Grad: 2.7289 Epoch: [1][4460/40998] Data 0.553 (0.553) Elapsed 63m 28s (remain 519m 51s) Loss: 0.1105(0.3614) Grad: 2.1635 Epoch: [1][4480/40998] Data 0.553 (0.553) Elapsed 63m 45s (remain 519m 33s) Loss: 0.5812(0.3614) Grad: 6.0290 Epoch: [1][4500/40998] Data 0.553 (0.553) Elapsed 64m 2s (remain 519m 16s) Loss: 0.1963(0.3610) Grad: 2.1509 Epoch: [1][4520/40998] Data 0.553 (0.553) Elapsed 64m 19s (remain 518m 59s) Loss: 0.1424(0.3606) Grad: 1.7593 Epoch: [1][4540/40998] Data 0.553 (0.553) Elapsed 64m 36s (remain 518m 42s) Loss: 0.2137(0.3603) Grad: 4.0668 Epoch: [1][4560/40998] Data 0.553 (0.553) Elapsed 64m 53s (remain 518m 25s) Loss: 0.4333(0.3600) Grad: 3.4192 Epoch: [1][4580/40998] Data 0.553 (0.553) Elapsed 65m 10s (remain 518m 8s) Loss: 0.2566(0.3598) Grad: 3.0899 Epoch: [1][4600/40998] Data 0.553 (0.553) Elapsed 65m 27s (remain 517m 50s) Loss: 0.2513(0.3593) Grad: 3.2703 Epoch: [1][4620/40998] Data 0.553 (0.553) Elapsed 65m 44s (remain 517m 33s) Loss: 0.0455(0.3589) Grad: 0.7746 Epoch: [1][4640/40998] Data 0.552 (0.553) Elapsed 66m 1s (remain 517m 16s) Loss: 0.1582(0.3585) Grad: 2.8523 Epoch: [1][4660/40998] Data 0.552 (0.553) Elapsed 66m 18s (remain 516m 59s) Loss: 0.4539(0.3580) Grad: 3.6360 Epoch: [1][4680/40998] Data 0.552 (0.553) Elapsed 66m 35s (remain 516m 42s) Loss: 0.4088(0.3579) Grad: 3.6551 Epoch: [1][4700/40998] Data 0.553 (0.553) Elapsed 66m 53s (remain 516m 24s) Loss: 0.2001(0.3578) Grad: 1.6809 Epoch: [1][4720/40998] Data 0.553 (0.553) Elapsed 67m 10s (remain 516m 7s) Loss: 0.1451(0.3574) Grad: 1.8845 Epoch: [1][4740/40998] Data 0.552 (0.553) Elapsed 67m 27s (remain 515m 50s) Loss: 0.1173(0.3571) Grad: 1.7397 Epoch: [1][4760/40998] Data 0.553 (0.553) Elapsed 67m 44s (remain 515m 33s) Loss: 0.1500(0.3568) Grad: 1.9881 Epoch: [1][4780/40998] Data 0.553 (0.553) Elapsed 68m 1s (remain 515m 16s) Loss: 0.1083(0.3563) Grad: 1.5966 Epoch: [1][4800/40998] Data 0.553 (0.553) Elapsed 68m 18s (remain 514m 59s) Loss: 1.3131(0.3562) Grad: 12.6580 Epoch: [1][4820/40998] Data 0.553 (0.553) Elapsed 68m 35s (remain 514m 42s) Loss: 0.1663(0.3557) Grad: 1.6973 Epoch: [1][4840/40998] Data 0.553 (0.553) Elapsed 68m 52s (remain 514m 24s) Loss: 0.1008(0.3552) Grad: 1.4107 Epoch: [1][4860/40998] Data 0.553 (0.553) Elapsed 69m 9s (remain 514m 7s) Loss: 0.6115(0.3548) Grad: 5.3901 Epoch: [1][4880/40998] Data 0.553 (0.553) Elapsed 69m 26s (remain 513m 50s) Loss: 0.2030(0.3543) Grad: 2.4576 Epoch: [1][4900/40998] Data 0.553 (0.553) Elapsed 69m 43s (remain 513m 33s) Loss: 0.1836(0.3539) Grad: 2.8831 Epoch: [1][4920/40998] Data 0.553 (0.553) Elapsed 70m 0s (remain 513m 16s) Loss: 0.3837(0.3535) Grad: 3.9599 Epoch: [1][4940/40998] Data 0.553 (0.553) Elapsed 70m 17s (remain 512m 59s) Loss: 0.3085(0.3532) Grad: 5.1727 Epoch: [1][4960/40998] Data 0.553 (0.553) Elapsed 70m 34s (remain 512m 41s) Loss: 0.4196(0.3530) Grad: 4.3328 Epoch: [1][4980/40998] Data 0.553 (0.553) Elapsed 70m 51s (remain 512m 24s) Loss: 0.1627(0.3526) Grad: 1.9059 Epoch: [1][5000/40998] Data 0.553 (0.553) Elapsed 71m 8s (remain 512m 7s) Loss: 0.1264(0.3524) Grad: 1.5472 Epoch: [1][5020/40998] Data 0.553 (0.553) Elapsed 71m 26s (remain 511m 50s) Loss: 0.2456(0.3521) Grad: 4.0302 Epoch: [1][5040/40998] Data 0.552 (0.553) Elapsed 71m 43s (remain 511m 33s) Loss: 0.2323(0.3518) Grad: 2.7378 Epoch: [1][5060/40998] Data 0.552 (0.553) Elapsed 72m 0s (remain 511m 16s) Loss: 0.1625(0.3513) Grad: 3.9250 Epoch: [1][5080/40998] Data 0.552 (0.553) Elapsed 72m 17s (remain 510m 59s) Loss: 0.1275(0.3509) Grad: 1.5964 Epoch: [1][5100/40998] Data 0.553 (0.553) Elapsed 72m 34s (remain 510m 41s) Loss: 0.1923(0.3506) Grad: 2.5608 Epoch: [1][5120/40998] Data 0.553 (0.553) Elapsed 72m 51s (remain 510m 24s) Loss: 0.3269(0.3503) Grad: 3.5254 Epoch: [1][5140/40998] Data 0.553 (0.553) Elapsed 73m 8s (remain 510m 7s) Loss: 0.2733(0.3500) Grad: 2.8224 Epoch: [1][5160/40998] Data 0.551 (0.553) Elapsed 73m 25s (remain 509m 50s) Loss: 0.1683(0.3497) Grad: 1.6668 Epoch: [1][5180/40998] Data 0.553 (0.553) Elapsed 73m 42s (remain 509m 33s) Loss: 0.3279(0.3493) Grad: 2.4237 Epoch: [1][5200/40998] Data 0.553 (0.553) Elapsed 73m 59s (remain 509m 16s) Loss: 0.1168(0.3491) Grad: 1.4203 Epoch: [1][5220/40998] Data 0.553 (0.553) Elapsed 74m 16s (remain 508m 58s) Loss: 0.2589(0.3488) Grad: 3.2881 Epoch: [1][5240/40998] Data 0.553 (0.553) Elapsed 74m 33s (remain 508m 41s) Loss: 0.1274(0.3483) Grad: 2.2733 Epoch: [1][5260/40998] Data 0.553 (0.553) Elapsed 74m 50s (remain 508m 24s) Loss: 0.2523(0.3481) Grad: 3.0727 Epoch: [1][5280/40998] Data 0.553 (0.553) Elapsed 75m 7s (remain 508m 7s) Loss: 0.3305(0.3479) Grad: 2.9528 Epoch: [1][5300/40998] Data 0.552 (0.553) Elapsed 75m 24s (remain 507m 50s) Loss: 0.2258(0.3474) Grad: 2.5228 Epoch: [1][5320/40998] Data 0.553 (0.553) Elapsed 75m 41s (remain 507m 33s) Loss: 0.3892(0.3472) Grad: 3.8145 Epoch: [1][5340/40998] Data 0.552 (0.553) Elapsed 75m 58s (remain 507m 16s) Loss: 0.3844(0.3468) Grad: 5.6329 Epoch: [1][5360/40998] Data 0.553 (0.553) Elapsed 76m 16s (remain 506m 58s) Loss: 0.7671(0.3466) Grad: 9.5603 Epoch: [1][5380/40998] Data 0.552 (0.553) Elapsed 76m 33s (remain 506m 41s) Loss: 0.3167(0.3464) Grad: 5.0064 Epoch: [1][5400/40998] Data 0.553 (0.553) Elapsed 76m 50s (remain 506m 24s) Loss: 0.2259(0.3461) Grad: 3.2661 Epoch: [1][5420/40998] Data 0.553 (0.553) Elapsed 77m 7s (remain 506m 7s) Loss: 0.4502(0.3457) Grad: 3.3215 Epoch: [1][5440/40998] Data 0.553 (0.553) Elapsed 77m 24s (remain 505m 50s) Loss: 0.2984(0.3454) Grad: 2.8672 Epoch: [1][5460/40998] Data 0.553 (0.553) Elapsed 77m 41s (remain 505m 33s) Loss: 0.3277(0.3454) Grad: 2.9040 Epoch: [1][5480/40998] Data 0.553 (0.553) Elapsed 77m 58s (remain 505m 16s) Loss: 0.1802(0.3451) Grad: 1.5881 Epoch: [1][5500/40998] Data 0.553 (0.553) Elapsed 78m 15s (remain 504m 59s) Loss: 0.6011(0.3450) Grad: 5.3194 Epoch: [1][5520/40998] Data 0.553 (0.553) Elapsed 78m 32s (remain 504m 41s) Loss: 0.0747(0.3447) Grad: 1.0549 Epoch: [1][5540/40998] Data 0.553 (0.553) Elapsed 78m 49s (remain 504m 24s) Loss: 0.1258(0.3444) Grad: 2.4118 Epoch: [1][5560/40998] Data 0.553 (0.553) Elapsed 79m 6s (remain 504m 7s) Loss: 0.2437(0.3441) Grad: 2.6214 Epoch: [1][5580/40998] Data 0.553 (0.553) Elapsed 79m 23s (remain 503m 50s) Loss: 0.3223(0.3438) Grad: 2.8051 Epoch: [1][5600/40998] Data 0.553 (0.553) Elapsed 79m 40s (remain 503m 33s) Loss: 0.2120(0.3434) Grad: 4.0831 Epoch: [1][5620/40998] Data 0.551 (0.553) Elapsed 79m 57s (remain 503m 16s) Loss: 0.2394(0.3430) Grad: 3.0675 Epoch: [1][5640/40998] Data 0.552 (0.553) Elapsed 80m 14s (remain 502m 59s) Loss: 0.0989(0.3429) Grad: 0.9570 Epoch: [1][5660/40998] Data 0.553 (0.553) Elapsed 80m 31s (remain 502m 41s) Loss: 0.1940(0.3425) Grad: 3.8546 Epoch: [1][5680/40998] Data 0.553 (0.553) Elapsed 80m 49s (remain 502m 24s) Loss: 0.1959(0.3422) Grad: 3.8876 Epoch: [1][5700/40998] Data 0.553 (0.553) Elapsed 81m 6s (remain 502m 7s) Loss: 0.4826(0.3421) Grad: 5.0981 Epoch: [1][5720/40998] Data 0.553 (0.553) Elapsed 81m 23s (remain 501m 50s) Loss: 0.2601(0.3420) Grad: 2.8846 Epoch: [1][5740/40998] Data 0.553 (0.553) Elapsed 81m 40s (remain 501m 33s) Loss: 0.3377(0.3417) Grad: 3.2157 Epoch: [1][5760/40998] Data 0.553 (0.553) Elapsed 81m 57s (remain 501m 16s) Loss: 0.1182(0.3413) Grad: 1.0749 Epoch: [1][5780/40998] Data 0.553 (0.553) Elapsed 82m 14s (remain 500m 59s) Loss: 0.2757(0.3411) Grad: 4.0697 Epoch: [1][5800/40998] Data 0.553 (0.553) Elapsed 82m 31s (remain 500m 42s) Loss: 0.2644(0.3408) Grad: 3.9125 Epoch: [1][5820/40998] Data 0.553 (0.553) Elapsed 82m 48s (remain 500m 24s) Loss: 0.1941(0.3405) Grad: 2.2366 Epoch: [1][5840/40998] Data 0.553 (0.553) Elapsed 83m 5s (remain 500m 7s) Loss: 0.1492(0.3403) Grad: 2.0205 Epoch: [1][5860/40998] Data 0.553 (0.553) Elapsed 83m 22s (remain 499m 50s) Loss: 0.1622(0.3401) Grad: 2.2099 Epoch: [1][5880/40998] Data 0.552 (0.553) Elapsed 83m 39s (remain 499m 33s) Loss: 0.1354(0.3399) Grad: 1.9441 Epoch: [1][5900/40998] Data 0.553 (0.553) Elapsed 83m 56s (remain 499m 16s) Loss: 0.2453(0.3396) Grad: 3.3732 Epoch: [1][5920/40998] Data 0.552 (0.553) Elapsed 84m 13s (remain 498m 59s) Loss: 0.4683(0.3394) Grad: 3.3761 Epoch: [1][5940/40998] Data 0.553 (0.553) Elapsed 84m 30s (remain 498m 42s) Loss: 0.2050(0.3392) Grad: 2.1987 Epoch: [1][5960/40998] Data 0.553 (0.553) Elapsed 84m 47s (remain 498m 24s) Loss: 0.1026(0.3389) Grad: 1.4665 Epoch: [1][5980/40998] Data 0.553 (0.553) Elapsed 85m 4s (remain 498m 7s) Loss: 0.5104(0.3385) Grad: 4.8246 Epoch: [1][6000/40998] Data 0.553 (0.553) Elapsed 85m 21s (remain 497m 50s) Loss: 0.6427(0.3383) Grad: 5.2195 Epoch: [1][6020/40998] Data 0.553 (0.553) Elapsed 85m 39s (remain 497m 33s) Loss: 0.3244(0.3380) Grad: 2.8280 Epoch: [1][6040/40998] Data 0.553 (0.553) Elapsed 85m 56s (remain 497m 16s) Loss: 0.2403(0.3376) Grad: 3.4935 Epoch: [1][6060/40998] Data 0.552 (0.553) Elapsed 86m 13s (remain 496m 59s) Loss: 0.1683(0.3372) Grad: 2.1730 Epoch: [1][6080/40998] Data 0.553 (0.553) Elapsed 86m 30s (remain 496m 42s) Loss: 0.1496(0.3369) Grad: 2.8679 Epoch: [1][6100/40998] Data 0.553 (0.553) Elapsed 86m 47s (remain 496m 25s) Loss: 0.1647(0.3366) Grad: 2.5862 Epoch: [1][6120/40998] Data 0.553 (0.553) Elapsed 87m 4s (remain 496m 7s) Loss: 0.2002(0.3365) Grad: 1.6921 Epoch: [1][6140/40998] Data 0.553 (0.553) Elapsed 87m 21s (remain 495m 50s) Loss: 0.1768(0.3361) Grad: 2.3348 Epoch: [1][6160/40998] Data 0.554 (0.553) Elapsed 87m 38s (remain 495m 33s) Loss: 0.1564(0.3358) Grad: 2.0361 Epoch: [1][6180/40998] Data 0.553 (0.553) Elapsed 87m 55s (remain 495m 16s) Loss: 0.1655(0.3355) Grad: 3.4164 Epoch: [1][6200/40998] Data 0.552 (0.553) Elapsed 88m 12s (remain 494m 59s) Loss: 0.1463(0.3353) Grad: 2.1613 Epoch: [1][6220/40998] Data 0.553 (0.553) Elapsed 88m 29s (remain 494m 42s) Loss: 0.1725(0.3351) Grad: 2.3651 Epoch: [1][6240/40998] Data 0.553 (0.553) Elapsed 88m 46s (remain 494m 25s) Loss: 0.1361(0.3348) Grad: 2.0024 Epoch: [1][6260/40998] Data 0.553 (0.553) Elapsed 89m 3s (remain 494m 8s) Loss: 0.4062(0.3346) Grad: 4.4391 Epoch: [1][6280/40998] Data 0.552 (0.553) Elapsed 89m 20s (remain 493m 51s) Loss: 0.0795(0.3343) Grad: 1.3309 Epoch: [1][6300/40998] Data 0.552 (0.553) Elapsed 89m 37s (remain 493m 33s) Loss: 0.2992(0.3340) Grad: 4.3284 Epoch: [1][6320/40998] Data 0.553 (0.553) Elapsed 89m 54s (remain 493m 16s) Loss: 0.5671(0.3336) Grad: 4.3166 Epoch: [1][6340/40998] Data 0.553 (0.553) Elapsed 90m 12s (remain 492m 59s) Loss: 0.1979(0.3333) Grad: 2.2547 Epoch: [1][6360/40998] Data 0.552 (0.553) Elapsed 90m 29s (remain 492m 42s) Loss: 0.5255(0.3332) Grad: 4.7358 Epoch: [1][6380/40998] Data 0.553 (0.553) Elapsed 90m 46s (remain 492m 25s) Loss: 0.2170(0.3329) Grad: 3.2626 Epoch: [1][6400/40998] Data 0.552 (0.553) Elapsed 91m 3s (remain 492m 8s) Loss: 0.1909(0.3326) Grad: 2.6825 Epoch: [1][6420/40998] Data 0.553 (0.553) Elapsed 91m 20s (remain 491m 51s) Loss: 0.2815(0.3323) Grad: 2.9678 Epoch: [1][6440/40998] Data 0.553 (0.553) Elapsed 91m 37s (remain 491m 34s) Loss: 0.4191(0.3321) Grad: 3.4090 Epoch: [1][6460/40998] Data 0.553 (0.553) Elapsed 91m 54s (remain 491m 17s) Loss: 0.1470(0.3319) Grad: 2.0783 Epoch: [1][6480/40998] Data 0.553 (0.553) Elapsed 92m 11s (remain 490m 59s) Loss: 0.2028(0.3316) Grad: 2.5423 Epoch: [1][6500/40998] Data 0.553 (0.553) Elapsed 92m 28s (remain 490m 42s) Loss: 0.4085(0.3314) Grad: 3.5672 Epoch: [1][6520/40998] Data 0.553 (0.553) Elapsed 92m 45s (remain 490m 25s) Loss: 0.3225(0.3311) Grad: 2.9536 Epoch: [1][6540/40998] Data 0.553 (0.553) Elapsed 93m 2s (remain 490m 8s) Loss: 0.1725(0.3308) Grad: 2.4769 Epoch: [1][6560/40998] Data 0.552 (0.553) Elapsed 93m 19s (remain 489m 51s) Loss: 0.2230(0.3307) Grad: 3.6438 Epoch: [1][6580/40998] Data 0.553 (0.553) Elapsed 93m 36s (remain 489m 34s) Loss: 0.4074(0.3306) Grad: 3.9932 Epoch: [1][6600/40998] Data 0.552 (0.553) Elapsed 93m 53s (remain 489m 17s) Loss: 0.3403(0.3304) Grad: 2.7387 Epoch: [1][6620/40998] Data 0.552 (0.553) Elapsed 94m 10s (remain 489m 0s) Loss: 0.1046(0.3301) Grad: 5.3190 Epoch: [1][6640/40998] Data 0.552 (0.553) Elapsed 94m 27s (remain 488m 43s) Loss: 0.1403(0.3298) Grad: 2.6859 Epoch: [1][6660/40998] Data 0.552 (0.553) Elapsed 94m 45s (remain 488m 25s) Loss: 0.2339(0.3295) Grad: 4.2010 Epoch: [1][6680/40998] Data 0.553 (0.553) Elapsed 95m 2s (remain 488m 8s) Loss: 0.2946(0.3292) Grad: 3.1304 Epoch: [1][6700/40998] Data 0.553 (0.553) Elapsed 95m 19s (remain 487m 51s) Loss: 0.4629(0.3289) Grad: 3.6245 Epoch: [1][6720/40998] Data 0.552 (0.553) Elapsed 95m 36s (remain 487m 34s) Loss: 0.2700(0.3287) Grad: 2.6775 Epoch: [1][6740/40998] Data 0.552 (0.553) Elapsed 95m 53s (remain 487m 17s) Loss: 0.4529(0.3283) Grad: 7.8528 Epoch: [1][6760/40998] Data 0.553 (0.553) Elapsed 96m 10s (remain 487m 0s) Loss: 0.2479(0.3279) Grad: 2.9501 Epoch: [1][6780/40998] Data 0.552 (0.553) Elapsed 96m 27s (remain 486m 43s) Loss: 0.3409(0.3277) Grad: 3.2633 Epoch: [1][6800/40998] Data 0.553 (0.553) Elapsed 96m 44s (remain 486m 26s) Loss: 0.1826(0.3273) Grad: 2.0594 Epoch: [1][6820/40998] Data 0.552 (0.553) Elapsed 97m 1s (remain 486m 9s) Loss: 0.3317(0.3269) Grad: 3.5999 Epoch: [1][6840/40998] Data 0.553 (0.553) Elapsed 97m 18s (remain 485m 52s) Loss: 0.5959(0.3267) Grad: 7.9406 Epoch: [1][6860/40998] Data 0.553 (0.553) Elapsed 97m 35s (remain 485m 34s) Loss: 0.1806(0.3265) Grad: 1.9975 Epoch: [1][6880/40998] Data 0.552 (0.553) Elapsed 97m 52s (remain 485m 17s) Loss: 0.3642(0.3263) Grad: 2.6956 Epoch: [1][6900/40998] Data 0.553 (0.553) Elapsed 98m 9s (remain 485m 0s) Loss: 0.3441(0.3261) Grad: 3.6517 Epoch: [1][6920/40998] Data 0.552 (0.553) Elapsed 98m 26s (remain 484m 43s) Loss: 0.4589(0.3257) Grad: 4.8953 Epoch: [1][6940/40998] Data 0.553 (0.553) Elapsed 98m 43s (remain 484m 26s) Loss: 0.4253(0.3256) Grad: 3.7347 Epoch: [1][6960/40998] Data 0.553 (0.553) Elapsed 99m 0s (remain 484m 9s) Loss: 0.1996(0.3254) Grad: 1.4495 Epoch: [1][6980/40998] Data 0.553 (0.553) Elapsed 99m 18s (remain 483m 52s) Loss: 0.1665(0.3251) Grad: 2.1453 Epoch: [1][7000/40998] Data 0.553 (0.553) Elapsed 99m 35s (remain 483m 35s) Loss: 0.2467(0.3247) Grad: 3.0714 Epoch: [1][7020/40998] Data 0.553 (0.553) Elapsed 99m 52s (remain 483m 18s) Loss: 0.1989(0.3244) Grad: 1.5400 Epoch: [1][7040/40998] Data 0.553 (0.553) Elapsed 100m 9s (remain 483m 1s) Loss: 0.2887(0.3242) Grad: 4.0853 Epoch: [1][7060/40998] Data 0.553 (0.553) Elapsed 100m 26s (remain 482m 44s) Loss: 0.5693(0.3239) Grad: 6.2781 Epoch: [1][7080/40998] Data 0.552 (0.553) Elapsed 100m 43s (remain 482m 26s) Loss: 0.2596(0.3237) Grad: 4.4527 Epoch: [1][7100/40998] Data 0.553 (0.553) Elapsed 101m 0s (remain 482m 9s) Loss: 0.0502(0.3233) Grad: 1.4079 Epoch: [1][7120/40998] Data 0.552 (0.553) Elapsed 101m 17s (remain 481m 52s) Loss: 0.3165(0.3231) Grad: 2.9040 Epoch: [1][7140/40998] Data 0.553 (0.553) Elapsed 101m 34s (remain 481m 35s) Loss: 0.2737(0.3229) Grad: 3.4684 Epoch: [1][7160/40998] Data 0.553 (0.553) Elapsed 101m 51s (remain 481m 18s) Loss: 0.2868(0.3228) Grad: 2.9093 Epoch: [1][7180/40998] Data 0.553 (0.553) Elapsed 102m 8s (remain 481m 1s) Loss: 0.2075(0.3225) Grad: 2.4910 Epoch: [1][7200/40998] Data 0.553 (0.553) Elapsed 102m 25s (remain 480m 44s) Loss: 0.0768(0.3221) Grad: 1.4086 Epoch: [1][7220/40998] Data 0.553 (0.553) Elapsed 102m 42s (remain 480m 27s) Loss: 0.1626(0.3219) Grad: 2.2545 Epoch: [1][7240/40998] Data 0.553 (0.553) Elapsed 102m 59s (remain 480m 10s) Loss: 0.3435(0.3217) Grad: 2.3138 Epoch: [1][7260/40998] Data 0.553 (0.553) Elapsed 103m 16s (remain 479m 53s) Loss: 0.1875(0.3214) Grad: 2.7984 Epoch: [1][7280/40998] Data 0.552 (0.553) Elapsed 103m 34s (remain 479m 36s) Loss: 0.0248(0.3213) Grad: 0.5163 Epoch: [1][7300/40998] Data 0.552 (0.553) Elapsed 103m 51s (remain 479m 18s) Loss: 0.2297(0.3213) Grad: 2.2662 Epoch: [1][7320/40998] Data 0.553 (0.553) Elapsed 104m 8s (remain 479m 1s) Loss: 0.4000(0.3212) Grad: 3.7396 Epoch: [1][7340/40998] Data 0.553 (0.553) Elapsed 104m 25s (remain 478m 44s) Loss: 0.3687(0.3209) Grad: 3.2830 Epoch: [1][7360/40998] Data 0.553 (0.553) Elapsed 104m 42s (remain 478m 27s) Loss: 0.1503(0.3207) Grad: 1.4308 Epoch: [1][7380/40998] Data 0.553 (0.553) Elapsed 104m 59s (remain 478m 10s) Loss: 0.1425(0.3207) Grad: 1.3802 Epoch: [1][7400/40998] Data 0.553 (0.553) Elapsed 105m 16s (remain 477m 53s) Loss: 0.3844(0.3205) Grad: 3.5767 Epoch: [1][7420/40998] Data 0.552 (0.553) Elapsed 105m 33s (remain 477m 36s) Loss: 0.0928(0.3203) Grad: 0.9478 Epoch: [1][7440/40998] Data 0.553 (0.553) Elapsed 105m 50s (remain 477m 19s) Loss: 0.3098(0.3200) Grad: 3.4717 Epoch: [1][7460/40998] Data 0.554 (0.553) Elapsed 106m 7s (remain 477m 2s) Loss: 0.3247(0.3198) Grad: 3.5050 Epoch: [1][7480/40998] Data 0.552 (0.553) Elapsed 106m 24s (remain 476m 45s) Loss: 0.1858(0.3195) Grad: 1.9472 Epoch: [1][7500/40998] Data 0.553 (0.553) Elapsed 106m 41s (remain 476m 28s) Loss: 0.3041(0.3193) Grad: 3.9544 Epoch: [1][7520/40998] Data 0.553 (0.553) Elapsed 106m 58s (remain 476m 10s) Loss: 0.1156(0.3191) Grad: 1.3310 Epoch: [1][7540/40998] Data 0.553 (0.553) Elapsed 107m 15s (remain 475m 53s) Loss: 0.1100(0.3189) Grad: 2.3017 Epoch: [1][7560/40998] Data 0.552 (0.553) Elapsed 107m 32s (remain 475m 36s) Loss: 0.1350(0.3187) Grad: 2.3109 Epoch: [1][7580/40998] Data 0.553 (0.553) Elapsed 107m 49s (remain 475m 19s) Loss: 0.1409(0.3184) Grad: 2.2654 Epoch: [1][7600/40998] Data 0.552 (0.553) Elapsed 108m 7s (remain 475m 2s) Loss: 0.3467(0.3181) Grad: 5.2347 Epoch: [1][7620/40998] Data 0.553 (0.553) Elapsed 108m 24s (remain 474m 45s) Loss: 0.2367(0.3179) Grad: 3.5917 Epoch: [1][7640/40998] Data 0.552 (0.553) Elapsed 108m 41s (remain 474m 28s) Loss: 0.1584(0.3176) Grad: 1.9672 Epoch: [1][7660/40998] Data 0.553 (0.553) Elapsed 108m 58s (remain 474m 11s) Loss: 0.1541(0.3174) Grad: 1.5600 Epoch: [1][7680/40998] Data 0.553 (0.553) Elapsed 109m 15s (remain 473m 54s) Loss: 0.1570(0.3171) Grad: 2.4373 Epoch: [1][7700/40998] Data 0.553 (0.553) Elapsed 109m 32s (remain 473m 37s) Loss: 0.4842(0.3172) Grad: 4.4670 Epoch: [1][7720/40998] Data 0.552 (0.553) Elapsed 109m 49s (remain 473m 20s) Loss: 0.1896(0.3169) Grad: 4.2196 Epoch: [1][7740/40998] Data 0.553 (0.553) Elapsed 110m 6s (remain 473m 2s) Loss: 0.1877(0.3167) Grad: 2.8570 Epoch: [1][7760/40998] Data 0.553 (0.553) Elapsed 110m 23s (remain 472m 45s) Loss: 0.2565(0.3164) Grad: 2.3546 Epoch: [1][7780/40998] Data 0.553 (0.553) Elapsed 110m 40s (remain 472m 28s) Loss: 0.1454(0.3162) Grad: 2.0119 Epoch: [1][7800/40998] Data 0.553 (0.553) Elapsed 110m 57s (remain 472m 11s) Loss: 0.1571(0.3160) Grad: 1.7658 Epoch: [1][7820/40998] Data 0.553 (0.553) Elapsed 111m 14s (remain 471m 54s) Loss: 0.4768(0.3158) Grad: 6.5107 Epoch: [1][7840/40998] Data 0.552 (0.553) Elapsed 111m 31s (remain 471m 37s) Loss: 0.1246(0.3155) Grad: 1.4379 Epoch: [1][7860/40998] Data 0.553 (0.553) Elapsed 111m 48s (remain 471m 20s) Loss: 0.2261(0.3154) Grad: 1.9103 Epoch: [1][7880/40998] Data 0.553 (0.553) Elapsed 112m 5s (remain 471m 3s) Loss: 0.2142(0.3152) Grad: 4.0401 Epoch: [1][7900/40998] Data 0.552 (0.553) Elapsed 112m 23s (remain 470m 46s) Loss: 0.2710(0.3150) Grad: 2.2426 Epoch: [1][7920/40998] Data 0.553 (0.553) Elapsed 112m 40s (remain 470m 29s) Loss: 0.4663(0.3147) Grad: 8.0442 Epoch: [1][7940/40998] Data 0.553 (0.553) Elapsed 112m 57s (remain 470m 12s) Loss: 0.2715(0.3146) Grad: 2.1494 Epoch: [1][7960/40998] Data 0.553 (0.553) Elapsed 113m 14s (remain 469m 54s) Loss: 0.1994(0.3144) Grad: 3.0286 Epoch: [1][7980/40998] Data 0.553 (0.553) Elapsed 113m 31s (remain 469m 37s) Loss: 0.3840(0.3142) Grad: 5.5330 Epoch: [1][8000/40998] Data 0.552 (0.553) Elapsed 113m 48s (remain 469m 20s) Loss: 0.5636(0.3139) Grad: 6.6573 Epoch: [1][8020/40998] Data 0.553 (0.553) Elapsed 114m 5s (remain 469m 3s) Loss: 0.2956(0.3138) Grad: 2.9597 Epoch: [1][8040/40998] Data 0.553 (0.553) Elapsed 114m 22s (remain 468m 46s) Loss: 0.3440(0.3136) Grad: 4.0119 Epoch: [1][8060/40998] Data 0.553 (0.553) Elapsed 114m 39s (remain 468m 29s) Loss: 0.2996(0.3133) Grad: 3.5992 Epoch: [1][8080/40998] Data 0.553 (0.553) Elapsed 114m 56s (remain 468m 14s) Loss: 0.3323(0.3130) Grad: 2.7511 Epoch: [1][8100/40998] Data 0.552 (0.553) Elapsed 115m 13s (remain 467m 55s) Loss: 0.1447(0.3129) Grad: 1.7840 Epoch: [1][8120/40998] Data 0.553 (0.553) Elapsed 115m 30s (remain 467m 38s) Loss: 0.0963(0.3127) Grad: 2.1877 Epoch: [1][8140/40998] Data 0.553 (0.553) Elapsed 115m 47s (remain 467m 21s) Loss: 0.2763(0.3124) Grad: 2.7835 Epoch: [1][8160/40998] Data 0.553 (0.553) Elapsed 116m 4s (remain 467m 4s) Loss: 0.1076(0.3123) Grad: 1.5792 Epoch: [1][8180/40998] Data 0.553 (0.553) Elapsed 116m 21s (remain 466m 47s) Loss: 0.0530(0.3121) Grad: 0.7544 Epoch: [1][8200/40998] Data 0.553 (0.553) Elapsed 116m 38s (remain 466m 29s) Loss: 0.3634(0.3120) Grad: 4.5098 Epoch: [1][8220/40998] Data 0.553 (0.553) Elapsed 116m 56s (remain 466m 12s) Loss: 0.0977(0.3118) Grad: 1.1528 Epoch: [1][8240/40998] Data 0.553 (0.553) Elapsed 117m 13s (remain 465m 55s) Loss: 0.2761(0.3115) Grad: 3.2182 Epoch: [1][8260/40998] Data 0.551 (0.553) Elapsed 117m 30s (remain 465m 38s) Loss: 0.1904(0.3112) Grad: 3.0004 Epoch: [1][8280/40998] Data 0.552 (0.553) Elapsed 117m 47s (remain 465m 21s) Loss: 0.3905(0.3112) Grad: 4.3826 Epoch: [1][8300/40998] Data 0.553 (0.553) Elapsed 118m 4s (remain 465m 4s) Loss: 0.2046(0.3110) Grad: 3.5094 Epoch: [1][8320/40998] Data 0.552 (0.553) Elapsed 118m 21s (remain 464m 47s) Loss: 0.1888(0.3107) Grad: 2.5930 Epoch: [1][8340/40998] Data 0.553 (0.553) Elapsed 118m 38s (remain 464m 30s) Loss: 0.1149(0.3103) Grad: 2.0985 Epoch: [1][8360/40998] Data 0.553 (0.553) Elapsed 118m 55s (remain 464m 13s) Loss: 0.0526(0.3102) Grad: 0.8573 Epoch: [1][8380/40998] Data 0.553 (0.553) Elapsed 119m 12s (remain 463m 56s) Loss: 0.1776(0.3100) Grad: 1.3333 Epoch: [1][8400/40998] Data 0.553 (0.553) Elapsed 119m 29s (remain 463m 39s) Loss: 0.2564(0.3098) Grad: 3.7807 Epoch: [1][8420/40998] Data 0.552 (0.553) Elapsed 119m 46s (remain 463m 21s) Loss: 0.0605(0.3095) Grad: 1.9135 Epoch: [1][8440/40998] Data 0.553 (0.553) Elapsed 120m 3s (remain 463m 4s) Loss: 0.0925(0.3092) Grad: 1.5798 Epoch: [1][8460/40998] Data 0.552 (0.553) Elapsed 120m 20s (remain 462m 47s) Loss: 0.2237(0.3090) Grad: 3.4717 Epoch: [1][8480/40998] Data 0.552 (0.553) Elapsed 120m 37s (remain 462m 30s) Loss: 0.1953(0.3086) Grad: 2.2369 Epoch: [1][8500/40998] Data 0.553 (0.553) Elapsed 120m 54s (remain 462m 13s) Loss: 0.2648(0.3084) Grad: 2.4650 Epoch: [1][8520/40998] Data 0.553 (0.553) Elapsed 121m 11s (remain 461m 56s) Loss: 0.1297(0.3082) Grad: 2.3939 Epoch: [1][8540/40998] Data 0.553 (0.553) Elapsed 121m 29s (remain 461m 39s) Loss: 0.2249(0.3080) Grad: 2.3639 Epoch: [1][8560/40998] Data 0.553 (0.553) Elapsed 121m 46s (remain 461m 22s) Loss: 0.1523(0.3079) Grad: 2.2518 Epoch: [1][8580/40998] Data 0.552 (0.553) Elapsed 122m 3s (remain 461m 5s) Loss: 0.1525(0.3076) Grad: 1.8975 Epoch: [1][8600/40998] Data 0.553 (0.553) Elapsed 122m 20s (remain 460m 48s) Loss: 0.5005(0.3076) Grad: 5.8201 Epoch: [1][8620/40998] Data 0.552 (0.553) Elapsed 122m 37s (remain 460m 31s) Loss: 0.2401(0.3074) Grad: 2.6190 Epoch: [1][8640/40998] Data 0.552 (0.553) Elapsed 122m 54s (remain 460m 13s) Loss: 0.0555(0.3072) Grad: 1.1835 Epoch: [1][8660/40998] Data 0.553 (0.553) Elapsed 123m 11s (remain 459m 56s) Loss: 0.1641(0.3070) Grad: 2.4011 Epoch: [1][8680/40998] Data 0.553 (0.553) Elapsed 123m 28s (remain 459m 39s) Loss: 0.2482(0.3067) Grad: 3.2998 Epoch: [1][8700/40998] Data 0.553 (0.553) Elapsed 123m 45s (remain 459m 22s) Loss: 0.1856(0.3065) Grad: 2.1897 Epoch: [1][8720/40998] Data 0.553 (0.553) Elapsed 124m 2s (remain 459m 5s) Loss: 0.1369(0.3063) Grad: 2.0961 Epoch: [1][8740/40998] Data 0.553 (0.553) Elapsed 124m 19s (remain 458m 48s) Loss: 0.3175(0.3061) Grad: 3.3475 Epoch: [1][8760/40998] Data 0.553 (0.553) Elapsed 124m 36s (remain 458m 31s) Loss: 0.1796(0.3061) Grad: 1.3738 Epoch: [1][8780/40998] Data 0.552 (0.553) Elapsed 124m 53s (remain 458m 14s) Loss: 0.3632(0.3060) Grad: 4.6801 Epoch: [1][8800/40998] Data 0.552 (0.553) Elapsed 125m 10s (remain 457m 57s) Loss: 0.1331(0.3057) Grad: 2.1122 Epoch: [1][8820/40998] Data 0.553 (0.553) Elapsed 125m 27s (remain 457m 40s) Loss: 0.3033(0.3055) Grad: 4.5069 Epoch: [1][8840/40998] Data 0.553 (0.553) Elapsed 125m 44s (remain 457m 23s) Loss: 0.1644(0.3052) Grad: 2.6688 Epoch: [1][8860/40998] Data 0.553 (0.553) Elapsed 126m 2s (remain 457m 5s) Loss: 0.2470(0.3050) Grad: 3.8769 Epoch: [1][8880/40998] Data 0.553 (0.553) Elapsed 126m 19s (remain 456m 48s) Loss: 0.1042(0.3049) Grad: 1.5973 Epoch: [1][8900/40998] Data 0.553 (0.553) Elapsed 126m 36s (remain 456m 31s) Loss: 0.0957(0.3048) Grad: 1.4324 Epoch: [1][8920/40998] Data 0.553 (0.553) Elapsed 126m 53s (remain 456m 14s) Loss: 0.0964(0.3046) Grad: 1.3578 Epoch: [1][8940/40998] Data 0.553 (0.553) Elapsed 127m 10s (remain 455m 57s) Loss: 0.0837(0.3043) Grad: 1.1818 Epoch: [1][8960/40998] Data 0.553 (0.553) Elapsed 127m 27s (remain 455m 40s) Loss: 0.3886(0.3042) Grad: 3.1877 Epoch: [1][8980/40998] Data 0.553 (0.553) Elapsed 127m 44s (remain 455m 23s) Loss: 0.1375(0.3040) Grad: 1.4285 Epoch: [1][9000/40998] Data 0.552 (0.553) Elapsed 128m 1s (remain 455m 6s) Loss: 0.1663(0.3038) Grad: 1.9197 Epoch: [1][9020/40998] Data 0.553 (0.553) Elapsed 128m 18s (remain 454m 49s) Loss: 0.2058(0.3036) Grad: 2.2798 Epoch: [1][9040/40998] Data 0.553 (0.553) Elapsed 128m 35s (remain 454m 32s) Loss: 0.0668(0.3033) Grad: 0.9697 Epoch: [1][9060/40998] Data 0.553 (0.553) Elapsed 128m 52s (remain 454m 15s) Loss: 0.3560(0.3031) Grad: 3.1274 Epoch: [1][9080/40998] Data 0.553 (0.553) Elapsed 129m 9s (remain 453m 57s) Loss: 0.3393(0.3029) Grad: 3.6336 Epoch: [1][9100/40998] Data 0.553 (0.553) Elapsed 129m 26s (remain 453m 40s) Loss: 0.3012(0.3028) Grad: 3.0730 Epoch: [1][9120/40998] Data 0.553 (0.553) Elapsed 129m 43s (remain 453m 23s) Loss: 0.2861(0.3026) Grad: 2.6326 Epoch: [1][9140/40998] Data 0.553 (0.553) Elapsed 130m 0s (remain 453m 6s) Loss: 0.0954(0.3024) Grad: 1.6698 Epoch: [1][9160/40998] Data 0.552 (0.553) Elapsed 130m 17s (remain 452m 49s) Loss: 0.1190(0.3023) Grad: 2.1619 Epoch: [1][9180/40998] Data 0.553 (0.553) Elapsed 130m 35s (remain 452m 32s) Loss: 0.1444(0.3020) Grad: 3.1157 Epoch: [1][9200/40998] Data 0.553 (0.553) Elapsed 130m 52s (remain 452m 15s) Loss: 0.1557(0.3019) Grad: 2.7366 Epoch: [1][9220/40998] Data 0.552 (0.553) Elapsed 131m 9s (remain 451m 58s) Loss: 0.1108(0.3017) Grad: 1.9261 Epoch: [1][9240/40998] Data 0.552 (0.553) Elapsed 131m 26s (remain 451m 41s) Loss: 0.1360(0.3014) Grad: 2.8110 Epoch: [1][9260/40998] Data 0.553 (0.553) Elapsed 131m 43s (remain 451m 24s) Loss: 0.1250(0.3012) Grad: 2.0337 Epoch: [1][9280/40998] Data 0.553 (0.553) Elapsed 132m 0s (remain 451m 7s) Loss: 0.1082(0.3010) Grad: 2.1692 Epoch: [1][9300/40998] Data 0.553 (0.553) Elapsed 132m 17s (remain 450m 50s) Loss: 0.0820(0.3007) Grad: 1.3971 Epoch: [1][9320/40998] Data 0.553 (0.553) Elapsed 132m 34s (remain 450m 32s) Loss: 0.1662(0.3006) Grad: 1.3420 Epoch: [1][9340/40998] Data 0.552 (0.553) Elapsed 132m 51s (remain 450m 15s) Loss: 0.0946(0.3004) Grad: 1.7330 Epoch: [1][9360/40998] Data 0.553 (0.553) Elapsed 133m 8s (remain 449m 58s) Loss: 0.4719(0.3004) Grad: 6.2861 Epoch: [1][9380/40998] Data 0.552 (0.553) Elapsed 133m 25s (remain 449m 41s) Loss: 0.1829(0.3002) Grad: 2.3480 Epoch: [1][9400/40998] Data 0.552 (0.553) Elapsed 133m 42s (remain 449m 24s) Loss: 0.2828(0.3001) Grad: 2.8431 Epoch: [1][9420/40998] Data 0.552 (0.553) Elapsed 133m 59s (remain 449m 7s) Loss: 0.2431(0.3000) Grad: 3.3322 Epoch: [1][9440/40998] Data 0.553 (0.553) Elapsed 134m 16s (remain 448m 50s) Loss: 0.0499(0.2997) Grad: 0.7240 Epoch: [1][9460/40998] Data 0.553 (0.553) Elapsed 134m 33s (remain 448m 33s) Loss: 0.1715(0.2995) Grad: 2.2409 Epoch: [1][9480/40998] Data 0.552 (0.553) Elapsed 134m 50s (remain 448m 16s) Loss: 0.3134(0.2995) Grad: 4.8003 Epoch: [1][9500/40998] Data 0.553 (0.553) Elapsed 135m 8s (remain 447m 59s) Loss: 0.4222(0.2993) Grad: 5.7802 Epoch: [1][9520/40998] Data 0.553 (0.553) Elapsed 135m 25s (remain 447m 41s) Loss: 0.1154(0.2990) Grad: 1.1436 Epoch: [1][9540/40998] Data 0.553 (0.553) Elapsed 135m 42s (remain 447m 24s) Loss: 0.3694(0.2989) Grad: 3.3291 Epoch: [1][9560/40998] Data 0.553 (0.553) Elapsed 135m 59s (remain 447m 7s) Loss: 0.2890(0.2987) Grad: 3.4256 Epoch: [1][9580/40998] Data 0.553 (0.553) Elapsed 136m 16s (remain 446m 50s) Loss: 0.2402(0.2986) Grad: 3.1124 Epoch: [1][9600/40998] Data 0.552 (0.553) Elapsed 136m 33s (remain 446m 33s) Loss: 0.1067(0.2984) Grad: 1.8211 Epoch: [1][9620/40998] Data 0.553 (0.553) Elapsed 136m 50s (remain 446m 16s) Loss: 0.2682(0.2982) Grad: 3.0080 Epoch: [1][9640/40998] Data 0.552 (0.553) Elapsed 137m 7s (remain 445m 59s) Loss: 0.3773(0.2979) Grad: 4.7305 Epoch: [1][9660/40998] Data 0.553 (0.553) Elapsed 137m 24s (remain 445m 42s) Loss: 0.0954(0.2977) Grad: 2.7707 Epoch: [1][9680/40998] Data 0.553 (0.553) Elapsed 137m 41s (remain 445m 25s) Loss: 0.2019(0.2975) Grad: 3.0861 Epoch: [1][9700/40998] Data 0.553 (0.553) Elapsed 137m 58s (remain 445m 8s) Loss: 0.0632(0.2974) Grad: 1.4904 Epoch: [1][9720/40998] Data 0.552 (0.553) Elapsed 138m 15s (remain 444m 51s) Loss: 0.1240(0.2972) Grad: 1.7863 Epoch: [1][9740/40998] Data 0.553 (0.553) Elapsed 138m 32s (remain 444m 33s) Loss: 0.1565(0.2970) Grad: 2.2276 Epoch: [1][9760/40998] Data 0.552 (0.553) Elapsed 138m 49s (remain 444m 16s) Loss: 0.1022(0.2968) Grad: 1.7565 Epoch: [1][9780/40998] Data 0.553 (0.553) Elapsed 139m 6s (remain 443m 59s) Loss: 0.2700(0.2966) Grad: 3.1657 Epoch: [1][9800/40998] Data 0.553 (0.553) Elapsed 139m 23s (remain 443m 42s) Loss: 0.1452(0.2963) Grad: 2.3832 Epoch: [1][9820/40998] Data 0.553 (0.553) Elapsed 139m 40s (remain 443m 25s) Loss: 0.2919(0.2961) Grad: 3.4792 Epoch: [1][9840/40998] Data 0.552 (0.553) Elapsed 139m 58s (remain 443m 8s) Loss: 0.6499(0.2961) Grad: 6.6629 Epoch: [1][9860/40998] Data 0.552 (0.553) Elapsed 140m 15s (remain 442m 51s) Loss: 0.4106(0.2960) Grad: 6.1597 Epoch: [1][9880/40998] Data 0.553 (0.553) Elapsed 140m 32s (remain 442m 34s) Loss: 0.1633(0.2958) Grad: 1.2065 Epoch: [1][9900/40998] Data 0.553 (0.553) Elapsed 140m 49s (remain 442m 17s) Loss: 0.1014(0.2957) Grad: 1.2803 Epoch: [1][9920/40998] Data 0.553 (0.553) Elapsed 141m 6s (remain 442m 0s) Loss: 0.7928(0.2955) Grad: 10.4774 Epoch: [1][9940/40998] Data 0.552 (0.553) Elapsed 141m 23s (remain 441m 43s) Loss: 0.3218(0.2954) Grad: 3.6060 Epoch: [1][9960/40998] Data 0.553 (0.553) Elapsed 141m 40s (remain 441m 26s) Loss: 0.5285(0.2953) Grad: 6.1914 Epoch: [1][9980/40998] Data 0.553 (0.553) Elapsed 141m 57s (remain 441m 8s) Loss: 0.0488(0.2951) Grad: 0.7472 Epoch: [1][10000/40998] Data 0.552 (0.553) Elapsed 142m 14s (remain 440m 51s) Loss: 0.1101(0.2949) Grad: 1.6955 Epoch: [1][10020/40998] Data 0.553 (0.553) Elapsed 142m 31s (remain 440m 34s) Loss: 0.1418(0.2947) Grad: 2.3735 Epoch: [1][10040/40998] Data 0.552 (0.553) Elapsed 142m 48s (remain 440m 17s) Loss: 0.1886(0.2945) Grad: 2.6143 Epoch: [1][10060/40998] Data 0.553 (0.553) Elapsed 143m 5s (remain 440m 0s) Loss: 0.1682(0.2943) Grad: 3.1873 Epoch: [1][10080/40998] Data 0.553 (0.553) Elapsed 143m 22s (remain 439m 43s) Loss: 0.1563(0.2941) Grad: 1.9813 Epoch: [1][10100/40998] Data 0.552 (0.553) Elapsed 143m 39s (remain 439m 26s) Loss: 0.1523(0.2940) Grad: 3.3176 Epoch: [1][10120/40998] Data 0.553 (0.553) Elapsed 143m 56s (remain 439m 9s) Loss: 0.1653(0.2939) Grad: 2.4759 Epoch: [1][10140/40998] Data 0.553 (0.553) Elapsed 144m 13s (remain 438m 52s) Loss: 0.2374(0.2937) Grad: 2.8077 Epoch: [1][10160/40998] Data 0.553 (0.553) Elapsed 144m 31s (remain 438m 35s) Loss: 0.1224(0.2936) Grad: 1.2285 Epoch: [1][10180/40998] Data 0.552 (0.553) Elapsed 144m 48s (remain 438m 18s) Loss: 0.4042(0.2934) Grad: 4.0073 Epoch: [1][10200/40998] Data 0.554 (0.553) Elapsed 145m 5s (remain 438m 1s) Loss: 0.4165(0.2933) Grad: 5.7505 Epoch: [1][10220/40998] Data 0.553 (0.553) Elapsed 145m 22s (remain 437m 44s) Loss: 0.1038(0.2931) Grad: 1.4230 Epoch: [1][10240/40998] Data 0.553 (0.553) Elapsed 145m 39s (remain 437m 26s) Loss: 0.0256(0.2929) Grad: 0.3938 Epoch: [1][10260/40998] Data 0.552 (0.553) Elapsed 145m 56s (remain 437m 9s) Loss: 0.2263(0.2928) Grad: 3.2504 Epoch: [1][10280/40998] Data 0.553 (0.553) Elapsed 146m 13s (remain 436m 52s) Loss: 0.5224(0.2927) Grad: 4.2521 Epoch: [1][10300/40998] Data 0.553 (0.553) Elapsed 146m 30s (remain 436m 35s) Loss: 0.2019(0.2925) Grad: 1.6706 Epoch: [1][10320/40998] Data 0.553 (0.553) Elapsed 146m 47s (remain 436m 18s) Loss: 0.3010(0.2924) Grad: 3.1065 Epoch: [1][10340/40998] Data 0.553 (0.553) Elapsed 147m 4s (remain 436m 1s) Loss: 0.0759(0.2922) Grad: 0.8501 Epoch: [1][10360/40998] Data 0.553 (0.553) Elapsed 147m 21s (remain 435m 44s) Loss: 0.4984(0.2920) Grad: 5.3230 Epoch: [1][10380/40998] Data 0.553 (0.553) Elapsed 147m 38s (remain 435m 27s) Loss: 0.1216(0.2918) Grad: 2.1360 Epoch: [1][10400/40998] Data 0.553 (0.553) Elapsed 147m 55s (remain 435m 10s) Loss: 0.1922(0.2917) Grad: 2.5480 Epoch: [1][10420/40998] Data 0.553 (0.553) Elapsed 148m 12s (remain 434m 53s) Loss: 0.1770(0.2915) Grad: 2.7236 Epoch: [1][10440/40998] Data 0.553 (0.553) Elapsed 148m 29s (remain 434m 36s) Loss: 0.1676(0.2913) Grad: 2.0044 Epoch: [1][10460/40998] Data 0.552 (0.553) Elapsed 148m 46s (remain 434m 19s) Loss: 0.6212(0.2912) Grad: 6.1231 Epoch: [1][10480/40998] Data 0.553 (0.553) Elapsed 149m 4s (remain 434m 1s) Loss: 0.4213(0.2909) Grad: 3.6924 Epoch: [1][10500/40998] Data 0.552 (0.553) Elapsed 149m 21s (remain 433m 44s) Loss: 0.2666(0.2907) Grad: 4.4078 Epoch: [1][10520/40998] Data 0.552 (0.553) Elapsed 149m 38s (remain 433m 27s) Loss: 0.0554(0.2906) Grad: 0.9174 Epoch: [1][10540/40998] Data 0.553 (0.553) Elapsed 149m 55s (remain 433m 10s) Loss: 0.2599(0.2904) Grad: 3.5663 Epoch: [1][10560/40998] Data 0.553 (0.553) Elapsed 150m 12s (remain 432m 53s) Loss: 0.0923(0.2903) Grad: 1.8463 Epoch: [1][10580/40998] Data 0.553 (0.553) Elapsed 150m 29s (remain 432m 36s) Loss: 0.1121(0.2901) Grad: 1.5749 Epoch: [1][10600/40998] Data 0.553 (0.553) Elapsed 150m 46s (remain 432m 19s) Loss: 0.0971(0.2899) Grad: 0.8963 Epoch: [1][10620/40998] Data 0.553 (0.553) Elapsed 151m 3s (remain 432m 2s) Loss: 0.3209(0.2898) Grad: 4.1762 Epoch: [1][10640/40998] Data 0.552 (0.553) Elapsed 151m 20s (remain 431m 45s) Loss: 0.3506(0.2896) Grad: 3.3387 Epoch: [1][10660/40998] Data 0.554 (0.553) Elapsed 151m 37s (remain 431m 28s) Loss: 0.3863(0.2895) Grad: 3.9963 Epoch: [1][10680/40998] Data 0.553 (0.553) Elapsed 151m 54s (remain 431m 11s) Loss: 0.2034(0.2894) Grad: 2.1361 Epoch: [1][10700/40998] Data 0.553 (0.553) Elapsed 152m 11s (remain 430m 54s) Loss: 0.1616(0.2892) Grad: 1.7826 Epoch: [1][10720/40998] Data 0.552 (0.553) Elapsed 152m 28s (remain 430m 36s) Loss: 0.4334(0.2891) Grad: 3.1840 Epoch: [1][10740/40998] Data 0.553 (0.553) Elapsed 152m 45s (remain 430m 19s) Loss: 0.0795(0.2889) Grad: 1.6903 Epoch: [1][10760/40998] Data 0.553 (0.553) Elapsed 153m 2s (remain 430m 2s) Loss: 0.4079(0.2888) Grad: 3.5098 Epoch: [1][10780/40998] Data 0.553 (0.553) Elapsed 153m 19s (remain 429m 45s) Loss: 0.1234(0.2885) Grad: 1.7600 Epoch: [1][10800/40998] Data 0.552 (0.553) Elapsed 153m 37s (remain 429m 28s) Loss: 0.1356(0.2883) Grad: 1.7237 Epoch: [1][10820/40998] Data 0.553 (0.553) Elapsed 153m 54s (remain 429m 11s) Loss: 0.3968(0.2882) Grad: 5.0036 Epoch: [1][10840/40998] Data 0.552 (0.553) Elapsed 154m 11s (remain 428m 54s) Loss: 0.1598(0.2881) Grad: 2.9145 Epoch: [1][10860/40998] Data 0.552 (0.553) Elapsed 154m 28s (remain 428m 37s) Loss: 0.6568(0.2879) Grad: 4.0339 Epoch: [1][10880/40998] Data 0.553 (0.553) Elapsed 154m 45s (remain 428m 20s) Loss: 0.1365(0.2879) Grad: 2.0833 Epoch: [1][10900/40998] Data 0.553 (0.553) Elapsed 155m 2s (remain 428m 3s) Loss: 0.0967(0.2877) Grad: 1.5512 Epoch: [1][10920/40998] Data 0.553 (0.553) Elapsed 155m 19s (remain 427m 46s) Loss: 0.2174(0.2875) Grad: 2.2061 Epoch: [1][10940/40998] Data 0.553 (0.553) Elapsed 155m 36s (remain 427m 29s) Loss: 0.1685(0.2874) Grad: 2.4892 Epoch: [1][10960/40998] Data 0.552 (0.553) Elapsed 155m 53s (remain 427m 11s) Loss: 0.1538(0.2872) Grad: 2.3856 Epoch: [1][10980/40998] Data 0.553 (0.553) Elapsed 156m 10s (remain 426m 54s) Loss: 0.0939(0.2871) Grad: 1.4951 Epoch: [1][11000/40998] Data 0.553 (0.553) Elapsed 156m 27s (remain 426m 37s) Loss: 0.2898(0.2870) Grad: 3.3033 Epoch: [1][11020/40998] Data 0.553 (0.553) Elapsed 156m 44s (remain 426m 20s) Loss: 0.3718(0.2868) Grad: 2.7753 Epoch: [1][11040/40998] Data 0.553 (0.553) Elapsed 157m 1s (remain 426m 3s) Loss: 0.1946(0.2866) Grad: 2.3912 Epoch: [1][11060/40998] Data 0.552 (0.553) Elapsed 157m 18s (remain 425m 46s) Loss: 0.3590(0.2865) Grad: 4.4216 Epoch: [1][11080/40998] Data 0.553 (0.553) Elapsed 157m 35s (remain 425m 29s) Loss: 0.2536(0.2864) Grad: 2.8354 Epoch: [1][11100/40998] Data 0.553 (0.553) Elapsed 157m 52s (remain 425m 12s) Loss: 0.3320(0.2862) Grad: 4.4136 Epoch: [1][11120/40998] Data 0.553 (0.553) Elapsed 158m 9s (remain 424m 55s) Loss: 0.1996(0.2860) Grad: 1.6666 Epoch: [1][11140/40998] Data 0.553 (0.553) Elapsed 158m 27s (remain 424m 38s) Loss: 0.2161(0.2858) Grad: 2.5799 Epoch: [1][11160/40998] Data 0.553 (0.553) Elapsed 158m 44s (remain 424m 21s) Loss: 0.3692(0.2857) Grad: 3.3311 Epoch: [1][11180/40998] Data 0.553 (0.553) Elapsed 159m 1s (remain 424m 3s) Loss: 0.1598(0.2856) Grad: 1.2819 Epoch: [1][11200/40998] Data 0.553 (0.553) Elapsed 159m 18s (remain 423m 46s) Loss: 0.2326(0.2854) Grad: 1.8353 Epoch: [1][11220/40998] Data 0.552 (0.553) Elapsed 159m 35s (remain 423m 29s) Loss: 0.1767(0.2853) Grad: 3.0297 Epoch: [1][11240/40998] Data 0.553 (0.553) Elapsed 159m 52s (remain 423m 12s) Loss: 0.1292(0.2851) Grad: 1.7293 Epoch: [1][11260/40998] Data 0.553 (0.553) Elapsed 160m 9s (remain 422m 55s) Loss: 0.0568(0.2849) Grad: 1.2657 Epoch: [1][11280/40998] Data 0.553 (0.553) Elapsed 160m 26s (remain 422m 38s) Loss: 0.1996(0.2847) Grad: 3.1683 Epoch: [1][11300/40998] Data 0.553 (0.553) Elapsed 160m 43s (remain 422m 21s) Loss: 0.3915(0.2846) Grad: 4.3228 Epoch: [1][11320/40998] Data 0.552 (0.553) Elapsed 161m 0s (remain 422m 4s) Loss: 0.3222(0.2845) Grad: 3.2798 Epoch: [1][11340/40998] Data 0.553 (0.553) Elapsed 161m 17s (remain 421m 47s) Loss: 0.1508(0.2844) Grad: 1.8556 Epoch: [1][11360/40998] Data 0.553 (0.553) Elapsed 161m 34s (remain 421m 30s) Loss: 0.0613(0.2842) Grad: 0.7567 Epoch: [1][11380/40998] Data 0.553 (0.553) Elapsed 161m 51s (remain 421m 13s) Loss: 0.1617(0.2841) Grad: 2.5667 Epoch: [1][11400/40998] Data 0.553 (0.553) Elapsed 162m 8s (remain 420m 56s) Loss: 0.4391(0.2841) Grad: 4.3833 Epoch: [1][11420/40998] Data 0.553 (0.553) Elapsed 162m 25s (remain 420m 38s) Loss: 0.2194(0.2840) Grad: 2.5412 Epoch: [1][11440/40998] Data 0.553 (0.553) Elapsed 162m 42s (remain 420m 21s) Loss: 0.3348(0.2839) Grad: 4.1984 Epoch: [1][11460/40998] Data 0.552 (0.553) Elapsed 163m 0s (remain 420m 4s) Loss: 0.1438(0.2838) Grad: 1.3391 Epoch: [1][11480/40998] Data 0.553 (0.553) Elapsed 163m 17s (remain 419m 47s) Loss: 0.0572(0.2836) Grad: 0.9536 Epoch: [1][11500/40998] Data 0.553 (0.553) Elapsed 163m 34s (remain 419m 30s) Loss: 0.2534(0.2834) Grad: 3.4075 Epoch: [1][11520/40998] Data 0.553 (0.553) Elapsed 163m 51s (remain 419m 13s) Loss: 0.2573(0.2833) Grad: 3.6560 Epoch: [1][11540/40998] Data 0.553 (0.553) Elapsed 164m 8s (remain 418m 56s) Loss: 0.0579(0.2832) Grad: 1.1304 Epoch: [1][11560/40998] Data 0.553 (0.553) Elapsed 164m 25s (remain 418m 39s) Loss: 0.1019(0.2831) Grad: 1.6954 Epoch: [1][11580/40998] Data 0.553 (0.553) Elapsed 164m 42s (remain 418m 22s) Loss: 0.2979(0.2830) Grad: 5.3529 Epoch: [1][11600/40998] Data 0.553 (0.553) Elapsed 164m 59s (remain 418m 5s) Loss: 0.2278(0.2828) Grad: 1.8658 Epoch: [1][11620/40998] Data 0.553 (0.553) Elapsed 165m 16s (remain 417m 48s) Loss: 0.1379(0.2827) Grad: 2.0557 Epoch: [1][11640/40998] Data 0.553 (0.553) Elapsed 165m 33s (remain 417m 31s) Loss: 0.2468(0.2826) Grad: 2.8676 Epoch: [1][11660/40998] Data 0.553 (0.553) Elapsed 165m 50s (remain 417m 13s) Loss: 0.4265(0.2825) Grad: 5.6323 Epoch: [1][11680/40998] Data 0.553 (0.553) Elapsed 166m 7s (remain 416m 56s) Loss: 0.1056(0.2823) Grad: 1.2849 Epoch: [1][11700/40998] Data 0.552 (0.553) Elapsed 166m 24s (remain 416m 39s) Loss: 0.1276(0.2821) Grad: 1.8242 Epoch: [1][11720/40998] Data 0.554 (0.553) Elapsed 166m 41s (remain 416m 22s) Loss: 0.2268(0.2820) Grad: 3.7482 Epoch: [1][11740/40998] Data 0.553 (0.553) Elapsed 166m 58s (remain 416m 5s) Loss: 0.3198(0.2819) Grad: 3.9428 Epoch: [1][11760/40998] Data 0.552 (0.553) Elapsed 167m 15s (remain 415m 48s) Loss: 0.4441(0.2817) Grad: 5.2166 Epoch: [1][11780/40998] Data 0.553 (0.553) Elapsed 167m 32s (remain 415m 31s) Loss: 0.1980(0.2816) Grad: 2.5334 Epoch: [1][11800/40998] Data 0.552 (0.553) Elapsed 167m 50s (remain 415m 14s) Loss: 0.1651(0.2815) Grad: 2.0166 Epoch: [1][11820/40998] Data 0.553 (0.553) Elapsed 168m 7s (remain 414m 57s) Loss: 0.1535(0.2814) Grad: 2.6256 Epoch: [1][11840/40998] Data 0.552 (0.553) Elapsed 168m 24s (remain 414m 40s) Loss: 0.3905(0.2813) Grad: 4.1221 Epoch: [1][11860/40998] Data 0.553 (0.553) Elapsed 168m 41s (remain 414m 23s) Loss: 0.2685(0.2810) Grad: 3.0302 Epoch: [1][11880/40998] Data 0.553 (0.553) Elapsed 168m 58s (remain 414m 6s) Loss: 0.2195(0.2809) Grad: 1.5101 Epoch: [1][11900/40998] Data 0.553 (0.553) Elapsed 169m 15s (remain 413m 49s) Loss: 0.1416(0.2807) Grad: 1.8085 Epoch: [1][11920/40998] Data 0.552 (0.553) Elapsed 169m 32s (remain 413m 31s) Loss: 0.1171(0.2805) Grad: 2.8623 Epoch: [1][11940/40998] Data 0.553 (0.553) Elapsed 169m 49s (remain 413m 14s) Loss: 0.1078(0.2804) Grad: 2.1971 Epoch: [1][11960/40998] Data 0.553 (0.553) Elapsed 170m 6s (remain 412m 57s) Loss: 0.2241(0.2802) Grad: 3.0881 Epoch: [1][11980/40998] Data 0.552 (0.553) Elapsed 170m 23s (remain 412m 40s) Loss: 0.1337(0.2801) Grad: 3.1350 Epoch: [1][12000/40998] Data 0.553 (0.553) Elapsed 170m 40s (remain 412m 23s) Loss: 0.4619(0.2800) Grad: 4.3302 Epoch: [1][12020/40998] Data 0.553 (0.553) Elapsed 170m 57s (remain 412m 6s) Loss: 0.2204(0.2799) Grad: 2.9538 Epoch: [1][12040/40998] Data 0.553 (0.553) Elapsed 171m 14s (remain 411m 49s) Loss: 0.0712(0.2797) Grad: 0.8668 Epoch: [1][12060/40998] Data 0.553 (0.553) Elapsed 171m 31s (remain 411m 32s) Loss: 0.3250(0.2798) Grad: 4.2233 Epoch: [1][12080/40998] Data 0.553 (0.553) Elapsed 171m 48s (remain 411m 15s) Loss: 0.3344(0.2797) Grad: 2.8720 Epoch: [1][12100/40998] Data 0.552 (0.553) Elapsed 172m 5s (remain 410m 58s) Loss: 0.1912(0.2796) Grad: 2.7060 Epoch: [1][12120/40998] Data 0.552 (0.553) Elapsed 172m 23s (remain 410m 41s) Loss: 0.1051(0.2794) Grad: 1.0307 Epoch: [1][12140/40998] Data 0.553 (0.553) Elapsed 172m 40s (remain 410m 24s) Loss: 0.4672(0.2794) Grad: 3.4981 Epoch: [1][12160/40998] Data 0.553 (0.553) Elapsed 172m 57s (remain 410m 6s) Loss: 0.2009(0.2793) Grad: 2.7937 Epoch: [1][12180/40998] Data 0.552 (0.553) Elapsed 173m 14s (remain 409m 49s) Loss: 0.1070(0.2791) Grad: 1.8084 Epoch: [1][12200/40998] Data 0.553 (0.553) Elapsed 173m 31s (remain 409m 32s) Loss: 0.0850(0.2790) Grad: 1.2566 Epoch: [1][12220/40998] Data 0.553 (0.553) Elapsed 173m 48s (remain 409m 15s) Loss: 0.0818(0.2789) Grad: 1.2756 Epoch: [1][12240/40998] Data 0.553 (0.553) Elapsed 174m 5s (remain 408m 58s) Loss: 0.1323(0.2787) Grad: 3.0956 Epoch: [1][12260/40998] Data 0.553 (0.553) Elapsed 174m 22s (remain 408m 41s) Loss: 0.0608(0.2785) Grad: 1.1836 Epoch: [1][12280/40998] Data 0.553 (0.553) Elapsed 174m 39s (remain 408m 24s) Loss: 0.5868(0.2785) Grad: 5.6230 Epoch: [1][12300/40998] Data 0.553 (0.553) Elapsed 174m 56s (remain 408m 7s) Loss: 0.1592(0.2784) Grad: 2.2664 Epoch: [1][12320/40998] Data 0.553 (0.553) Elapsed 175m 13s (remain 407m 50s) Loss: 0.1328(0.2782) Grad: 1.7508 Epoch: [1][12340/40998] Data 0.553 (0.553) Elapsed 175m 30s (remain 407m 33s) Loss: 0.0962(0.2780) Grad: 1.9179 Epoch: [1][12360/40998] Data 0.553 (0.553) Elapsed 175m 47s (remain 407m 16s) Loss: 0.0645(0.2779) Grad: 0.8233 Epoch: [1][12380/40998] Data 0.553 (0.553) Elapsed 176m 4s (remain 406m 59s) Loss: 0.1937(0.2777) Grad: 2.9432 Epoch: [1][12400/40998] Data 0.552 (0.553) Elapsed 176m 21s (remain 406m 42s) Loss: 0.2210(0.2776) Grad: 3.5599 Epoch: [1][12420/40998] Data 0.553 (0.553) Elapsed 176m 38s (remain 406m 24s) Loss: 0.0612(0.2775) Grad: 1.1432 Epoch: [1][12440/40998] Data 0.552 (0.553) Elapsed 176m 55s (remain 406m 7s) Loss: 0.2375(0.2773) Grad: 3.4121 Epoch: [1][12460/40998] Data 0.553 (0.553) Elapsed 177m 13s (remain 405m 50s) Loss: 0.3726(0.2772) Grad: 2.5110 Epoch: [1][12480/40998] Data 0.553 (0.553) Elapsed 177m 30s (remain 405m 33s) Loss: 0.2202(0.2770) Grad: 3.1691 Epoch: [1][12500/40998] Data 0.553 (0.553) Elapsed 177m 47s (remain 405m 16s) Loss: 0.0823(0.2768) Grad: 1.2501 Epoch: [1][12520/40998] Data 0.553 (0.553) Elapsed 178m 4s (remain 404m 59s) Loss: 0.0437(0.2766) Grad: 0.8163 Epoch: [1][12540/40998] Data 0.552 (0.553) Elapsed 178m 21s (remain 404m 42s) Loss: 0.1465(0.2765) Grad: 1.9858 Epoch: [1][12560/40998] Data 0.553 (0.553) Elapsed 178m 38s (remain 404m 25s) Loss: 0.3101(0.2763) Grad: 4.8959 Epoch: [1][12580/40998] Data 0.553 (0.553) Elapsed 178m 55s (remain 404m 8s) Loss: 0.1404(0.2762) Grad: 1.5938 Epoch: [1][12600/40998] Data 0.553 (0.553) Elapsed 179m 12s (remain 403m 51s) Loss: 0.2195(0.2761) Grad: 2.4182 Epoch: [1][12620/40998] Data 0.552 (0.553) Elapsed 179m 29s (remain 403m 34s) Loss: 0.1222(0.2759) Grad: 2.2186 Epoch: [1][12640/40998] Data 0.552 (0.553) Elapsed 179m 46s (remain 403m 17s) Loss: 0.1959(0.2757) Grad: 2.9698 Epoch: [1][12660/40998] Data 0.553 (0.553) Elapsed 180m 3s (remain 403m 0s) Loss: 0.1455(0.2756) Grad: 3.3279 Epoch: [1][12680/40998] Data 0.553 (0.553) Elapsed 180m 20s (remain 402m 42s) Loss: 0.4566(0.2755) Grad: 5.2471 Epoch: [1][12700/40998] Data 0.553 (0.553) Elapsed 180m 37s (remain 402m 25s) Loss: 0.1578(0.2753) Grad: 2.5220 Epoch: [1][12720/40998] Data 0.552 (0.553) Elapsed 180m 54s (remain 402m 8s) Loss: 0.1604(0.2753) Grad: 1.4084 Epoch: [1][12740/40998] Data 0.553 (0.553) Elapsed 181m 11s (remain 401m 51s) Loss: 0.3172(0.2751) Grad: 3.1501 Epoch: [1][12760/40998] Data 0.553 (0.553) Elapsed 181m 28s (remain 401m 34s) Loss: 0.1468(0.2749) Grad: 2.8432 Epoch: [1][12780/40998] Data 0.553 (0.553) Elapsed 181m 46s (remain 401m 17s) Loss: 0.2035(0.2747) Grad: 1.9159 Epoch: [1][12800/40998] Data 0.553 (0.553) Elapsed 182m 3s (remain 401m 0s) Loss: 0.1491(0.2747) Grad: 2.8245 Epoch: [1][12820/40998] Data 0.553 (0.553) Elapsed 182m 20s (remain 400m 43s) Loss: 0.1649(0.2745) Grad: 1.7007 Epoch: [1][12840/40998] Data 0.553 (0.553) Elapsed 182m 37s (remain 400m 26s) Loss: 0.5324(0.2745) Grad: 6.4621 Epoch: [1][12860/40998] Data 0.553 (0.553) Elapsed 182m 54s (remain 400m 9s) Loss: 0.1526(0.2744) Grad: 2.3720 Epoch: [1][12880/40998] Data 0.553 (0.553) Elapsed 183m 11s (remain 399m 52s) Loss: 0.3317(0.2743) Grad: 3.5787 Epoch: [1][12900/40998] Data 0.552 (0.553) Elapsed 183m 28s (remain 399m 35s) Loss: 0.4026(0.2744) Grad: 3.1777 Epoch: [1][12920/40998] Data 0.553 (0.553) Elapsed 183m 45s (remain 399m 18s) Loss: 0.2084(0.2742) Grad: 1.9413 Epoch: [1][12940/40998] Data 0.553 (0.553) Elapsed 184m 2s (remain 399m 0s) Loss: 0.1895(0.2742) Grad: 1.7200 Epoch: [1][12960/40998] Data 0.553 (0.553) Elapsed 184m 19s (remain 398m 43s) Loss: 0.0665(0.2741) Grad: 1.7799 Epoch: [1][12980/40998] Data 0.554 (0.553) Elapsed 184m 36s (remain 398m 26s) Loss: 0.2867(0.2740) Grad: 2.7899 Epoch: [1][13000/40998] Data 0.552 (0.553) Elapsed 184m 53s (remain 398m 9s) Loss: 0.3508(0.2738) Grad: 4.1978 Epoch: [1][13020/40998] Data 0.553 (0.553) Elapsed 185m 10s (remain 397m 52s) Loss: 0.1074(0.2737) Grad: 1.8525 Epoch: [1][13040/40998] Data 0.552 (0.553) Elapsed 185m 27s (remain 397m 35s) Loss: 0.1009(0.2736) Grad: 1.8458 Epoch: [1][13060/40998] Data 0.553 (0.553) Elapsed 185m 44s (remain 397m 18s) Loss: 0.0428(0.2734) Grad: 0.7681 Epoch: [1][13080/40998] Data 0.553 (0.553) Elapsed 186m 1s (remain 397m 1s) Loss: 0.5188(0.2733) Grad: 4.2225 Epoch: [1][13100/40998] Data 0.552 (0.553) Elapsed 186m 19s (remain 396m 44s) Loss: 0.1424(0.2731) Grad: 2.8646 Epoch: [1][13120/40998] Data 0.553 (0.553) Elapsed 186m 36s (remain 396m 27s) Loss: 0.1236(0.2730) Grad: 1.4198 Epoch: [1][13140/40998] Data 0.552 (0.553) Elapsed 186m 53s (remain 396m 10s) Loss: 0.2005(0.2730) Grad: 2.7191 Epoch: [1][13160/40998] Data 0.553 (0.553) Elapsed 187m 10s (remain 395m 53s) Loss: 0.1143(0.2728) Grad: 1.9377 Epoch: [1][13180/40998] Data 0.552 (0.553) Elapsed 187m 27s (remain 395m 36s) Loss: 0.0813(0.2727) Grad: 1.9954 Epoch: [1][13200/40998] Data 0.553 (0.553) Elapsed 187m 44s (remain 395m 18s) Loss: 0.2514(0.2725) Grad: 5.0005 Epoch: [1][13220/40998] Data 0.553 (0.553) Elapsed 188m 1s (remain 395m 1s) Loss: 0.1540(0.2723) Grad: 3.1987 Epoch: [1][13240/40998] Data 0.553 (0.553) Elapsed 188m 18s (remain 394m 44s) Loss: 0.1607(0.2723) Grad: 2.1414 Epoch: [1][13260/40998] Data 0.553 (0.553) Elapsed 188m 35s (remain 394m 27s) Loss: 0.0780(0.2722) Grad: 1.3876 Epoch: [1][13280/40998] Data 0.553 (0.553) Elapsed 188m 52s (remain 394m 10s) Loss: 0.0738(0.2720) Grad: 1.4101 Epoch: [1][13300/40998] Data 0.553 (0.553) Elapsed 189m 9s (remain 393m 53s) Loss: 0.2876(0.2718) Grad: 2.8325 Epoch: [1][13320/40998] Data 0.553 (0.553) Elapsed 189m 26s (remain 393m 36s) Loss: 0.1924(0.2717) Grad: 1.5424 Epoch: [1][13340/40998] Data 0.553 (0.553) Elapsed 189m 43s (remain 393m 19s) Loss: 0.4726(0.2716) Grad: 3.5446 Epoch: [1][13360/40998] Data 0.553 (0.553) Elapsed 190m 0s (remain 393m 2s) Loss: 0.0591(0.2714) Grad: 1.6184 Epoch: [1][13380/40998] Data 0.553 (0.553) Elapsed 190m 17s (remain 392m 45s) Loss: 0.1623(0.2714) Grad: 2.6068 Epoch: [1][13400/40998] Data 0.553 (0.553) Elapsed 190m 34s (remain 392m 28s) Loss: 0.3665(0.2712) Grad: 3.1257 Epoch: [1][13420/40998] Data 0.553 (0.553) Elapsed 190m 52s (remain 392m 11s) Loss: 0.2133(0.2710) Grad: 2.2932 Epoch: [1][13440/40998] Data 0.553 (0.553) Elapsed 191m 9s (remain 391m 54s) Loss: 0.5188(0.2709) Grad: 4.5501 Epoch: [1][13460/40998] Data 0.553 (0.553) Elapsed 191m 26s (remain 391m 37s) Loss: 0.0752(0.2707) Grad: 1.2922 Epoch: [1][13480/40998] Data 0.553 (0.553) Elapsed 191m 43s (remain 391m 19s) Loss: 0.1093(0.2706) Grad: 1.7965 Epoch: [1][13500/40998] Data 0.553 (0.553) Elapsed 192m 0s (remain 391m 2s) Loss: 0.0759(0.2705) Grad: 1.4157 Epoch: [1][13520/40998] Data 0.553 (0.553) Elapsed 192m 17s (remain 390m 45s) Loss: 0.1198(0.2704) Grad: 1.5453 Epoch: [1][13540/40998] Data 0.552 (0.553) Elapsed 192m 34s (remain 390m 28s) Loss: 0.2351(0.2702) Grad: 2.8289 Epoch: [1][13560/40998] Data 0.553 (0.553) Elapsed 192m 51s (remain 390m 11s) Loss: 0.0649(0.2701) Grad: 1.0979 Epoch: [1][13580/40998] Data 0.553 (0.553) Elapsed 193m 8s (remain 389m 54s) Loss: 0.2389(0.2699) Grad: 2.1562 Epoch: [1][13600/40998] Data 0.553 (0.553) Elapsed 193m 25s (remain 389m 37s) Loss: 0.0966(0.2698) Grad: 1.6538 Epoch: [1][13620/40998] Data 0.553 (0.553) Elapsed 193m 42s (remain 389m 20s) Loss: 0.0821(0.2696) Grad: 1.9282 Epoch: [1][13640/40998] Data 0.552 (0.553) Elapsed 193m 59s (remain 389m 3s) Loss: 0.2066(0.2695) Grad: 3.0901 Epoch: [1][13660/40998] Data 0.553 (0.553) Elapsed 194m 16s (remain 388m 46s) Loss: 0.1418(0.2693) Grad: 2.2154 Epoch: [1][13680/40998] Data 0.553 (0.553) Elapsed 194m 33s (remain 388m 29s) Loss: 0.3653(0.2692) Grad: 4.4659 Epoch: [1][13700/40998] Data 0.553 (0.553) Elapsed 194m 50s (remain 388m 12s) Loss: 0.2434(0.2691) Grad: 3.1364 Epoch: [1][13720/40998] Data 0.553 (0.553) Elapsed 195m 7s (remain 387m 55s) Loss: 0.1268(0.2690) Grad: 1.1831 Epoch: [1][13740/40998] Data 0.552 (0.553) Elapsed 195m 24s (remain 387m 38s) Loss: 0.1555(0.2689) Grad: 1.7581 Epoch: [1][13760/40998] Data 0.553 (0.553) Elapsed 195m 42s (remain 387m 20s) Loss: 0.0566(0.2688) Grad: 1.1725 Epoch: [1][13780/40998] Data 0.553 (0.553) Elapsed 195m 59s (remain 387m 3s) Loss: 0.0174(0.2687) Grad: 0.2886 Epoch: [1][13800/40998] Data 0.552 (0.553) Elapsed 196m 16s (remain 386m 46s) Loss: 0.2373(0.2686) Grad: 2.7507 Epoch: [1][13820/40998] Data 0.553 (0.553) Elapsed 196m 33s (remain 386m 29s) Loss: 0.4093(0.2685) Grad: 2.8707 Epoch: [1][13840/40998] Data 0.553 (0.553) Elapsed 196m 50s (remain 386m 12s) Loss: 0.1244(0.2684) Grad: 2.1606 Epoch: [1][13860/40998] Data 0.552 (0.553) Elapsed 197m 7s (remain 385m 55s) Loss: 0.1576(0.2682) Grad: 1.5257 Epoch: [1][13880/40998] Data 0.553 (0.553) Elapsed 197m 24s (remain 385m 38s) Loss: 0.3625(0.2681) Grad: 6.8079 Epoch: [1][13900/40998] Data 0.553 (0.553) Elapsed 197m 41s (remain 385m 21s) Loss: 0.2467(0.2680) Grad: 2.1558 Epoch: [1][13920/40998] Data 0.553 (0.553) Elapsed 197m 58s (remain 385m 4s) Loss: 0.1414(0.2679) Grad: 2.4116 Epoch: [1][13940/40998] Data 0.552 (0.553) Elapsed 198m 15s (remain 384m 47s) Loss: 0.1597(0.2678) Grad: 2.4256 Epoch: [1][13960/40998] Data 0.553 (0.553) Elapsed 198m 32s (remain 384m 30s) Loss: 0.1836(0.2677) Grad: 1.8756 Epoch: [1][13980/40998] Data 0.553 (0.553) Elapsed 198m 49s (remain 384m 13s) Loss: 0.1919(0.2675) Grad: 2.7549 Epoch: [1][14000/40998] Data 0.553 (0.553) Elapsed 199m 6s (remain 383m 56s) Loss: 0.0990(0.2675) Grad: 1.5303 Epoch: [1][14020/40998] Data 0.552 (0.553) Elapsed 199m 23s (remain 383m 38s) Loss: 0.2932(0.2674) Grad: 2.5059 Epoch: [1][14040/40998] Data 0.553 (0.553) Elapsed 199m 40s (remain 383m 21s) Loss: 0.6504(0.2673) Grad: 4.8618 Epoch: [1][14060/40998] Data 0.553 (0.553) Elapsed 199m 57s (remain 383m 4s) Loss: 0.0658(0.2671) Grad: 1.0269 Epoch: [1][14080/40998] Data 0.552 (0.553) Elapsed 200m 15s (remain 382m 47s) Loss: 0.1637(0.2669) Grad: 3.0674 Epoch: [1][14100/40998] Data 0.553 (0.553) Elapsed 200m 32s (remain 382m 30s) Loss: 0.3702(0.2668) Grad: 3.6453 Epoch: [1][14120/40998] Data 0.553 (0.553) Elapsed 200m 49s (remain 382m 13s) Loss: 0.0752(0.2667) Grad: 1.3577 Epoch: [1][14140/40998] Data 0.553 (0.553) Elapsed 201m 6s (remain 381m 56s) Loss: 0.1755(0.2665) Grad: 2.7587 Epoch: [1][14160/40998] Data 0.553 (0.553) Elapsed 201m 23s (remain 381m 39s) Loss: 0.4580(0.2665) Grad: 5.6603 Epoch: [1][14180/40998] Data 0.553 (0.553) Elapsed 201m 40s (remain 381m 22s) Loss: 0.0533(0.2664) Grad: 0.6712 Epoch: [1][14200/40998] Data 0.552 (0.553) Elapsed 201m 57s (remain 381m 5s) Loss: 0.2404(0.2662) Grad: 2.2382 Epoch: [1][14220/40998] Data 0.553 (0.553) Elapsed 202m 14s (remain 380m 48s) Loss: 0.1961(0.2661) Grad: 2.9530 Epoch: [1][14240/40998] Data 0.553 (0.553) Elapsed 202m 31s (remain 380m 31s) Loss: 0.2574(0.2660) Grad: 2.5785 Epoch: [1][14260/40998] Data 0.553 (0.553) Elapsed 202m 48s (remain 380m 14s) Loss: 0.2457(0.2659) Grad: 3.2384 Epoch: [1][14280/40998] Data 0.553 (0.553) Elapsed 203m 5s (remain 379m 57s) Loss: 0.1459(0.2658) Grad: 2.2444 Epoch: [1][14300/40998] Data 0.553 (0.553) Elapsed 203m 22s (remain 379m 40s) Loss: 0.0841(0.2657) Grad: 1.2871 Epoch: [1][14320/40998] Data 0.553 (0.553) Elapsed 203m 39s (remain 379m 22s) Loss: 0.1690(0.2656) Grad: 1.7188 Epoch: [1][14340/40998] Data 0.553 (0.553) Elapsed 203m 56s (remain 379m 5s) Loss: 0.1390(0.2655) Grad: 1.7920 Epoch: [1][14360/40998] Data 0.553 (0.553) Elapsed 204m 13s (remain 378m 48s) Loss: 0.0463(0.2653) Grad: 1.2446 Epoch: [1][14380/40998] Data 0.553 (0.553) Elapsed 204m 30s (remain 378m 31s) Loss: 0.1054(0.2652) Grad: 1.5822 Epoch: [1][14400/40998] Data 0.553 (0.553) Elapsed 204m 48s (remain 378m 14s) Loss: 0.2709(0.2651) Grad: 2.1543 Epoch: [1][14420/40998] Data 0.553 (0.553) Elapsed 205m 5s (remain 377m 57s) Loss: 0.2199(0.2650) Grad: 2.4892 Epoch: [1][14440/40998] Data 0.553 (0.553) Elapsed 205m 22s (remain 377m 40s) Loss: 0.0687(0.2649) Grad: 0.7389 Epoch: [1][14460/40998] Data 0.553 (0.553) Elapsed 205m 39s (remain 377m 23s) Loss: 0.2767(0.2648) Grad: 4.8113 Epoch: [1][14480/40998] Data 0.553 (0.553) Elapsed 205m 56s (remain 377m 6s) Loss: 0.4241(0.2647) Grad: 3.4242 Epoch: [1][14500/40998] Data 0.553 (0.553) Elapsed 206m 13s (remain 376m 49s) Loss: 0.0687(0.2646) Grad: 1.1305 Epoch: [1][14520/40998] Data 0.553 (0.553) Elapsed 206m 30s (remain 376m 32s) Loss: 0.0617(0.2645) Grad: 1.2359 Epoch: [1][14540/40998] Data 0.553 (0.553) Elapsed 206m 47s (remain 376m 15s) Loss: 0.1282(0.2644) Grad: 2.7051 Epoch: [1][14560/40998] Data 0.552 (0.553) Elapsed 207m 4s (remain 375m 58s) Loss: 0.1404(0.2643) Grad: 2.1054 Epoch: [1][14580/40998] Data 0.553 (0.553) Elapsed 207m 21s (remain 375m 41s) Loss: 0.0817(0.2642) Grad: 1.1332 Epoch: [1][14600/40998] Data 0.553 (0.553) Elapsed 207m 38s (remain 375m 23s) Loss: 0.2427(0.2641) Grad: 3.0237 Epoch: [1][14620/40998] Data 0.552 (0.553) Elapsed 207m 55s (remain 375m 6s) Loss: 0.1971(0.2639) Grad: 4.5254 Epoch: [1][14640/40998] Data 0.553 (0.553) Elapsed 208m 12s (remain 374m 49s) Loss: 0.4228(0.2639) Grad: 2.9514 Epoch: [1][14660/40998] Data 0.552 (0.553) Elapsed 208m 29s (remain 374m 32s) Loss: 0.1687(0.2637) Grad: 3.0671 Epoch: [1][14680/40998] Data 0.553 (0.553) Elapsed 208m 46s (remain 374m 15s) Loss: 0.1871(0.2636) Grad: 1.8404 Epoch: [1][14700/40998] Data 0.551 (0.553) Elapsed 209m 4s (remain 373m 58s) Loss: 0.1380(0.2635) Grad: 1.9553 Epoch: [1][14720/40998] Data 0.553 (0.553) Elapsed 209m 21s (remain 373m 41s) Loss: 0.3043(0.2634) Grad: 2.0631 Epoch: [1][14740/40998] Data 0.552 (0.553) Elapsed 209m 38s (remain 373m 24s) Loss: 0.2329(0.2633) Grad: 1.7245 Epoch: [1][14760/40998] Data 0.553 (0.553) Elapsed 209m 55s (remain 373m 7s) Loss: 0.0757(0.2632) Grad: 0.9867 Epoch: [1][14780/40998] Data 0.553 (0.553) Elapsed 210m 12s (remain 372m 50s) Loss: 0.2747(0.2630) Grad: 4.1493 Epoch: [1][14800/40998] Data 0.552 (0.553) Elapsed 210m 29s (remain 372m 33s) Loss: 0.1762(0.2629) Grad: 2.4743 Epoch: [1][14820/40998] Data 0.553 (0.553) Elapsed 210m 46s (remain 372m 16s) Loss: 0.1532(0.2627) Grad: 2.4626 Epoch: [1][14840/40998] Data 0.553 (0.553) Elapsed 211m 3s (remain 371m 59s) Loss: 0.0270(0.2626) Grad: 0.4831 Epoch: [1][14860/40998] Data 0.553 (0.553) Elapsed 211m 20s (remain 371m 42s) Loss: 0.0933(0.2625) Grad: 1.0528 Epoch: [1][14880/40998] Data 0.554 (0.553) Elapsed 211m 37s (remain 371m 25s) Loss: 0.1074(0.2624) Grad: 1.3818 Epoch: [1][14900/40998] Data 0.552 (0.553) Elapsed 211m 54s (remain 371m 8s) Loss: 0.1552(0.2623) Grad: 2.4732 Epoch: [1][14920/40998] Data 0.553 (0.553) Elapsed 212m 11s (remain 370m 50s) Loss: 0.0755(0.2622) Grad: 1.0346 Epoch: [1][14940/40998] Data 0.553 (0.553) Elapsed 212m 28s (remain 370m 33s) Loss: 0.5354(0.2621) Grad: 9.7909 Epoch: [1][14960/40998] Data 0.553 (0.553) Elapsed 212m 45s (remain 370m 16s) Loss: 0.2245(0.2620) Grad: 4.2239 Epoch: [1][14980/40998] Data 0.553 (0.553) Elapsed 213m 2s (remain 369m 59s) Loss: 0.2694(0.2619) Grad: 3.8000 Epoch: [1][15000/40998] Data 0.553 (0.553) Elapsed 213m 20s (remain 369m 42s) Loss: 0.2818(0.2618) Grad: 2.8980 Epoch: [1][15020/40998] Data 0.553 (0.553) Elapsed 213m 37s (remain 369m 25s) Loss: 0.2178(0.2617) Grad: 2.7797 Epoch: [1][15040/40998] Data 0.553 (0.553) Elapsed 213m 54s (remain 369m 8s) Loss: 0.2140(0.2616) Grad: 3.3398 Epoch: [1][15060/40998] Data 0.553 (0.553) Elapsed 214m 11s (remain 368m 51s) Loss: 0.0599(0.2614) Grad: 1.0378 Epoch: [1][15080/40998] Data 0.553 (0.553) Elapsed 214m 28s (remain 368m 34s) Loss: 0.3475(0.2612) Grad: 4.3085 Epoch: [1][15100/40998] Data 0.553 (0.553) Elapsed 214m 45s (remain 368m 17s) Loss: 0.2737(0.2612) Grad: 2.5683 Epoch: [1][15120/40998] Data 0.553 (0.553) Elapsed 215m 2s (remain 368m 0s) Loss: 0.2075(0.2611) Grad: 4.2108 Epoch: [1][15260/40998] Data 0.552 (0.553) Elapsed 217m 1s (remain 366m 0s) Loss: 0.3030(0.2605) Grad: 2.5120 Epoch: [1][15280/40998] Data 0.553 (0.553) Elapsed 217m 18s (remain 365m 43s) Loss: 0.4748(0.2603) Grad: 4.0414 Epoch: [1][15300/40998] Data 0.553 (0.553) Elapsed 217m 35s (remain 365m 26s) Loss: 0.1179(0.2602) Grad: 1.6243 Epoch: [1][15320/40998] Data 0.552 (0.553) Elapsed 217m 53s (remain 365m 9s) Loss: 0.2600(0.2600) Grad: 2.6786 Epoch: [1][15340/40998] Data 0.553 (0.553) Elapsed 218m 10s (remain 364m 52s) Loss: 0.0788(0.2600) Grad: 1.1963 Epoch: [1][15360/40998] Data 0.552 (0.553) Elapsed 218m 27s (remain 364m 35s) Loss: 0.1258(0.2599) Grad: 2.2987 Epoch: [1][15380/40998] Data 0.553 (0.553) Elapsed 218m 44s (remain 364m 18s) Loss: 0.0321(0.2598) Grad: 0.5624 Epoch: [1][15440/40998] Data 0.553 (0.553) Elapsed 219m 35s (remain 363m 27s) Loss: 0.0484(0.2593) Grad: 0.8680 Epoch: [1][15460/40998] Data 0.552 (0.553) Elapsed 219m 52s (remain 363m 10s) Loss: 0.2398(0.2592) Grad: 1.9606 Epoch: [1][15480/40998] Data 0.553 (0.553) Elapsed 220m 9s (remain 362m 52s) Loss: 0.2526(0.2591) Grad: 3.7836 Epoch: [1][15500/40998] Data 0.553 (0.553) Elapsed 220m 26s (remain 362m 35s) Loss: 0.0701(0.2590) Grad: 0.9821 Epoch: [1][15520/40998] Data 0.553 (0.553) Elapsed 220m 43s (remain 362m 18s) Loss: 0.3168(0.2589) Grad: 3.4109 Epoch: [1][15540/40998] Data 0.553 (0.553) Elapsed 221m 0s (remain 362m 1s) Loss: 0.1052(0.2588) Grad: 2.0457 Epoch: [1][15560/40998] Data 0.553 (0.553) Elapsed 221m 17s (remain 361m 44s) Loss: 0.0465(0.2587) Grad: 0.8731 Epoch: [1][15580/40998] Data 0.553 (0.553) Elapsed 221m 34s (remain 361m 27s) Loss: 0.1282(0.2586) Grad: 2.6126 Epoch: [1][15600/40998] Data 0.553 (0.553) Elapsed 221m 51s (remain 361m 10s) Loss: 0.1525(0.2585) Grad: 1.3204 Epoch: [1][15620/40998] Data 0.554 (0.553) Elapsed 222m 8s (remain 360m 53s) Loss: 0.1702(0.2583) Grad: 2.2571 Epoch: [1][15640/40998] Data 0.553 (0.553) Elapsed 222m 25s (remain 360m 36s) Loss: 0.0587(0.2581) Grad: 0.8714 Epoch: [1][15660/40998] Data 0.553 (0.553) Elapsed 222m 43s (remain 360m 19s) Loss: 0.1006(0.2580) Grad: 1.6794 Epoch: [1][15680/40998] Data 0.553 (0.553) Elapsed 223m 0s (remain 360m 2s) Loss: 0.0593(0.2579) Grad: 1.2584 Epoch: [1][15700/40998] Data 0.553 (0.553) Elapsed 223m 17s (remain 359m 45s) Loss: 0.0812(0.2578) Grad: 1.9852 Epoch: [1][15720/40998] Data 0.554 (0.553) Elapsed 223m 34s (remain 359m 28s) Loss: 0.2763(0.2577) Grad: 5.4404 Epoch: [1][15740/40998] Data 0.553 (0.553) Elapsed 223m 51s (remain 359m 11s) Loss: 0.0817(0.2576) Grad: 1.4999 Epoch: [1][15760/40998] Data 0.553 (0.553) Elapsed 224m 8s (remain 358m 53s) Loss: 0.2203(0.2575) Grad: 1.8733 Epoch: [1][15780/40998] Data 0.551 (0.553) Elapsed 224m 25s (remain 358m 36s) Loss: 0.1179(0.2574) Grad: 2.0530 Epoch: [1][15800/40998] Data 0.553 (0.553) Elapsed 224m 42s (remain 358m 19s) Loss: 0.4213(0.2573) Grad: 3.6555 Epoch: [1][15820/40998] Data 0.553 (0.553) Elapsed 224m 59s (remain 358m 2s) Loss: 0.1180(0.2572) Grad: 1.4346 Epoch: [1][15840/40998] Data 0.552 (0.553) Elapsed 225m 16s (remain 357m 45s) Loss: 0.2150(0.2571) Grad: 2.6488 Epoch: [1][15860/40998] Data 0.553 (0.553) Elapsed 225m 33s (remain 357m 28s) Loss: 0.1406(0.2570) Grad: 2.6654 Epoch: [1][15880/40998] Data 0.553 (0.553) Elapsed 225m 50s (remain 357m 11s) Loss: 0.1001(0.2569) Grad: 2.0635 Epoch: [1][15900/40998] Data 0.553 (0.553) Elapsed 226m 7s (remain 356m 54s) Loss: 0.1162(0.2568) Grad: 1.2938 Epoch: [1][15920/40998] Data 0.553 (0.553) Elapsed 226m 24s (remain 356m 37s) Loss: 0.2432(0.2567) Grad: 2.1032 Epoch: [1][15940/40998] Data 0.553 (0.553) Elapsed 226m 41s (remain 356m 20s) Loss: 0.1921(0.2566) Grad: 2.4337 Epoch: [1][15960/40998] Data 0.552 (0.553) Elapsed 226m 58s (remain 356m 3s) Loss: 0.1593(0.2565) Grad: 1.9173 Epoch: [1][15980/40998] Data 0.553 (0.553) Elapsed 227m 16s (remain 355m 46s) Loss: 0.1437(0.2564) Grad: 2.6238 Epoch: [1][16000/40998] Data 0.553 (0.553) Elapsed 227m 33s (remain 355m 29s) Loss: 0.0911(0.2563) Grad: 1.4334 Epoch: [1][16020/40998] Data 0.552 (0.553) Elapsed 227m 50s (remain 355m 11s) Loss: 0.0852(0.2562) Grad: 1.9538 Epoch: [1][16040/40998] Data 0.553 (0.553) Elapsed 228m 7s (remain 354m 54s) Loss: 0.0798(0.2561) Grad: 1.8507 Epoch: [1][16060/40998] Data 0.551 (0.553) Elapsed 228m 24s (remain 354m 37s) Loss: 0.1155(0.2560) Grad: 3.8219 Epoch: [1][16080/40998] Data 0.553 (0.553) Elapsed 228m 41s (remain 354m 20s) Loss: 0.0496(0.2559) Grad: 0.8304 Epoch: [1][16100/40998] Data 0.553 (0.553) Elapsed 228m 58s (remain 354m 3s) Loss: 0.2334(0.2557) Grad: 2.7643 Epoch: [1][16120/40998] Data 0.553 (0.553) Elapsed 229m 15s (remain 353m 46s) Loss: 0.1083(0.2556) Grad: 2.2074 Epoch: [1][16140/40998] Data 0.553 (0.553) Elapsed 229m 32s (remain 353m 29s) Loss: 0.1463(0.2555) Grad: 1.8759 Epoch: [1][16160/40998] Data 0.553 (0.553) Elapsed 229m 49s (remain 353m 12s) Loss: 0.1010(0.2554) Grad: 2.0354 Epoch: [1][16180/40998] Data 0.553 (0.553) Elapsed 230m 6s (remain 352m 55s) Loss: 0.3762(0.2554) Grad: 6.5814 Epoch: [1][16200/40998] Data 0.553 (0.553) Elapsed 230m 23s (remain 352m 38s) Loss: 0.1587(0.2552) Grad: 1.7839 Epoch: [1][16220/40998] Data 0.553 (0.553) Elapsed 230m 40s (remain 352m 21s) Loss: 0.1720(0.2551) Grad: 2.5345 Epoch: [1][16240/40998] Data 0.552 (0.553) Elapsed 230m 57s (remain 352m 4s) Loss: 0.1951(0.2550) Grad: 3.1240 Epoch: [1][16260/40998] Data 0.553 (0.553) Elapsed 231m 14s (remain 351m 47s) Loss: 0.1698(0.2549) Grad: 3.8891 Epoch: [1][16280/40998] Data 0.553 (0.553) Elapsed 231m 31s (remain 351m 30s) Loss: 0.3463(0.2548) Grad: 3.1023 Epoch: [1][16300/40998] Data 0.553 (0.553) Elapsed 231m 49s (remain 351m 13s) Loss: 0.1600(0.2548) Grad: 2.4028 Epoch: [1][16320/40998] Data 0.553 (0.553) Elapsed 232m 6s (remain 350m 55s) Loss: 0.4809(0.2547) Grad: 5.3612 Epoch: [1][16340/40998] Data 0.553 (0.553) Elapsed 232m 23s (remain 350m 38s) Loss: 0.1293(0.2546) Grad: 1.3372 Epoch: [1][16360/40998] Data 0.553 (0.553) Elapsed 232m 40s (remain 350m 21s) Loss: 0.2961(0.2545) Grad: 4.6070 Epoch: [1][16380/40998] Data 0.552 (0.553) Elapsed 232m 57s (remain 350m 4s) Loss: 0.1762(0.2545) Grad: 1.4548 Epoch: [1][16400/40998] Data 0.552 (0.553) Elapsed 233m 14s (remain 349m 47s) Loss: 0.1788(0.2544) Grad: 1.6819 Epoch: [1][16420/40998] Data 0.553 (0.553) Elapsed 233m 31s (remain 349m 30s) Loss: 0.3669(0.2543) Grad: 7.9949 Epoch: [1][16440/40998] Data 0.553 (0.553) Elapsed 233m 48s (remain 349m 13s) Loss: 0.1853(0.2542) Grad: 2.2960 Epoch: [1][16460/40998] Data 0.553 (0.553) Elapsed 234m 5s (remain 348m 56s) Loss: 0.1317(0.2541) Grad: 2.3917 Epoch: [1][16480/40998] Data 0.553 (0.553) Elapsed 234m 22s (remain 348m 39s) Loss: 0.2448(0.2540) Grad: 2.8505 Epoch: [1][16500/40998] Data 0.553 (0.553) Elapsed 234m 39s (remain 348m 22s) Loss: 0.1250(0.2540) Grad: 2.7002 Epoch: [1][16520/40998] Data 0.552 (0.553) Elapsed 234m 56s (remain 348m 5s) Loss: 0.0731(0.2539) Grad: 1.0064 Epoch: [1][16540/40998] Data 0.553 (0.553) Elapsed 235m 13s (remain 347m 48s) Loss: 0.1491(0.2538) Grad: 2.1927 Epoch: [1][16560/40998] Data 0.552 (0.553) Elapsed 235m 30s (remain 347m 31s) Loss: 0.5655(0.2537) Grad: 4.6435 Epoch: [1][16580/40998] Data 0.553 (0.553) Elapsed 235m 47s (remain 347m 14s) Loss: 0.0820(0.2537) Grad: 1.8205 Epoch: [1][16600/40998] Data 0.553 (0.553) Elapsed 236m 4s (remain 346m 56s) Loss: 0.2811(0.2536) Grad: 6.2083 Epoch: [1][16620/40998] Data 0.553 (0.553) Elapsed 236m 22s (remain 346m 39s) Loss: 0.0928(0.2535) Grad: 1.6256 Epoch: [1][16640/40998] Data 0.552 (0.553) Elapsed 236m 39s (remain 346m 22s) Loss: 0.0480(0.2534) Grad: 0.7441 Epoch: [1][16660/40998] Data 0.553 (0.553) Elapsed 236m 56s (remain 346m 5s) Loss: 0.2888(0.2534) Grad: 3.9175 Epoch: [1][16680/40998] Data 0.553 (0.553) Elapsed 237m 13s (remain 345m 48s) Loss: 0.3647(0.2533) Grad: 3.3249 Epoch: [1][16700/40998] Data 0.553 (0.553) Elapsed 237m 30s (remain 345m 31s) Loss: 0.0516(0.2531) Grad: 0.8696 Epoch: [1][16720/40998] Data 0.553 (0.553) Elapsed 237m 47s (remain 345m 14s) Loss: 0.1479(0.2531) Grad: 2.9752 Epoch: [1][16740/40998] Data 0.553 (0.553) Elapsed 238m 4s (remain 344m 57s) Loss: 0.4763(0.2530) Grad: 5.1271 Epoch: [1][16760/40998] Data 0.553 (0.553) Elapsed 238m 21s (remain 344m 40s) Loss: 0.1313(0.2529) Grad: 1.9964 Epoch: [1][16780/40998] Data 0.553 (0.553) Elapsed 238m 38s (remain 344m 23s) Loss: 0.0684(0.2528) Grad: 1.3994 Epoch: [1][16800/40998] Data 0.552 (0.553) Elapsed 238m 55s (remain 344m 6s) Loss: 0.1320(0.2527) Grad: 1.5602 Epoch: [1][16820/40998] Data 0.553 (0.553) Elapsed 239m 12s (remain 343m 49s) Loss: 0.3163(0.2526) Grad: 3.2209 Epoch: [1][16840/40998] Data 0.552 (0.553) Elapsed 239m 29s (remain 343m 32s) Loss: 0.0771(0.2525) Grad: 0.6902 Epoch: [1][16860/40998] Data 0.553 (0.553) Elapsed 239m 46s (remain 343m 15s) Loss: 0.0243(0.2524) Grad: 0.5582 Epoch: [1][16880/40998] Data 0.553 (0.553) Elapsed 240m 3s (remain 342m 57s) Loss: 0.2015(0.2523) Grad: 1.9992 Epoch: [1][16900/40998] Data 0.553 (0.553) Elapsed 240m 20s (remain 342m 40s) Loss: 0.3221(0.2523) Grad: 3.6064 Epoch: [1][16920/40998] Data 0.553 (0.553) Elapsed 240m 37s (remain 342m 23s) Loss: 0.3331(0.2522) Grad: 3.8411 Epoch: [1][16940/40998] Data 0.553 (0.553) Elapsed 240m 54s (remain 342m 6s) Loss: 0.1680(0.2520) Grad: 1.7458 Epoch: [1][16960/40998] Data 0.553 (0.553) Elapsed 241m 12s (remain 341m 49s) Loss: 0.1518(0.2519) Grad: 4.4711 Epoch: [1][16980/40998] Data 0.553 (0.553) Elapsed 241m 29s (remain 341m 32s) Loss: 0.0633(0.2518) Grad: 1.2474 Epoch: [1][17000/40998] Data 0.553 (0.553) Elapsed 241m 46s (remain 341m 15s) Loss: 0.1849(0.2517) Grad: 1.9928 Epoch: [1][17020/40998] Data 0.553 (0.553) Elapsed 242m 3s (remain 340m 58s) Loss: 0.1266(0.2517) Grad: 2.7471 Epoch: [1][17040/40998] Data 0.553 (0.553) Elapsed 242m 20s (remain 340m 41s) Loss: 0.0536(0.2516) Grad: 0.5454 Epoch: [1][17060/40998] Data 0.553 (0.553) Elapsed 242m 37s (remain 340m 24s) Loss: 0.1152(0.2515) Grad: 1.4696 Epoch: [1][17080/40998] Data 0.552 (0.553) Elapsed 242m 54s (remain 340m 7s) Loss: 0.1169(0.2514) Grad: 3.1552 Epoch: [1][17100/40998] Data 0.553 (0.553) Elapsed 243m 11s (remain 339m 50s) Loss: 0.1567(0.2513) Grad: 1.6248 Epoch: [1][17120/40998] Data 0.552 (0.553) Elapsed 243m 28s (remain 339m 33s) Loss: 0.2288(0.2512) Grad: 4.0448 Epoch: [1][17140/40998] Data 0.552 (0.553) Elapsed 243m 45s (remain 339m 16s) Loss: 0.1395(0.2512) Grad: 2.3428 Epoch: [1][17160/40998] Data 0.553 (0.553) Elapsed 244m 2s (remain 338m 58s) Loss: 0.4210(0.2511) Grad: 3.0639 Epoch: [1][17180/40998] Data 0.553 (0.553) Elapsed 244m 19s (remain 338m 41s) Loss: 0.2881(0.2511) Grad: 3.2067 Epoch: [1][17200/40998] Data 0.553 (0.553) Elapsed 244m 36s (remain 338m 24s) Loss: 0.0448(0.2510) Grad: 0.6094 Epoch: [1][17220/40998] Data 0.553 (0.553) Elapsed 244m 53s (remain 338m 7s) Loss: 0.1482(0.2508) Grad: 2.8622 Epoch: [1][17240/40998] Data 0.553 (0.553) Elapsed 245m 10s (remain 337m 50s) Loss: 0.1655(0.2508) Grad: 1.6844 Epoch: [1][17260/40998] Data 0.552 (0.553) Elapsed 245m 27s (remain 337m 33s) Loss: 0.1472(0.2507) Grad: 2.3813 Epoch: [1][17280/40998] Data 0.553 (0.553) Elapsed 245m 45s (remain 337m 16s) Loss: 0.3899(0.2506) Grad: 3.2701 Epoch: [1][17300/40998] Data 0.553 (0.553) Elapsed 246m 2s (remain 336m 59s) Loss: 0.0811(0.2506) Grad: 0.5731 Epoch: [1][17320/40998] Data 0.553 (0.553) Elapsed 246m 19s (remain 336m 42s) Loss: 0.1597(0.2504) Grad: 1.9984 Epoch: [1][17340/40998] Data 0.553 (0.553) Elapsed 246m 36s (remain 336m 25s) Loss: 0.0406(0.2504) Grad: 0.9213 Epoch: [1][17360/40998] Data 0.552 (0.553) Elapsed 246m 53s (remain 336m 8s) Loss: 0.2431(0.2503) Grad: 2.9174 Epoch: [1][17380/40998] Data 0.553 (0.553) Elapsed 247m 10s (remain 335m 51s) Loss: 0.2570(0.2502) Grad: 4.4230 Epoch: [1][17400/40998] Data 0.553 (0.553) Elapsed 247m 27s (remain 335m 34s) Loss: 0.0810(0.2501) Grad: 0.9942 Epoch: [1][17420/40998] Data 0.553 (0.553) Elapsed 247m 44s (remain 335m 17s) Loss: 0.0589(0.2501) Grad: 1.0411 Epoch: [1][17440/40998] Data 0.552 (0.553) Elapsed 248m 1s (remain 334m 59s) Loss: 0.0639(0.2500) Grad: 1.6222 Epoch: [1][17460/40998] Data 0.553 (0.553) Elapsed 248m 18s (remain 334m 42s) Loss: 0.0288(0.2499) Grad: 0.3935 Epoch: [1][17480/40998] Data 0.553 (0.553) Elapsed 248m 35s (remain 334m 25s) Loss: 0.1202(0.2498) Grad: 1.3275 Epoch: [1][17500/40998] Data 0.553 (0.553) Elapsed 248m 52s (remain 334m 8s) Loss: 0.1213(0.2497) Grad: 2.3754 Epoch: [1][17520/40998] Data 0.553 (0.553) Elapsed 249m 9s (remain 333m 51s) Loss: 0.0334(0.2496) Grad: 0.5979 Epoch: [1][17540/40998] Data 0.552 (0.553) Elapsed 249m 26s (remain 333m 34s) Loss: 0.0470(0.2495) Grad: 1.3209 Epoch: [1][17560/40998] Data 0.553 (0.553) Elapsed 249m 43s (remain 333m 17s) Loss: 0.0873(0.2494) Grad: 1.1118 Epoch: [1][17580/40998] Data 0.552 (0.553) Elapsed 250m 0s (remain 333m 0s) Loss: 0.1014(0.2493) Grad: 2.4408 Epoch: [1][17600/40998] Data 0.552 (0.553) Elapsed 250m 17s (remain 332m 43s) Loss: 0.1332(0.2493) Grad: 4.1381 Epoch: [1][17620/40998] Data 0.553 (0.553) Elapsed 250m 35s (remain 332m 26s) Loss: 0.0795(0.2491) Grad: 1.2570 Epoch: [1][17640/40998] Data 0.553 (0.553) Elapsed 250m 52s (remain 332m 9s) Loss: 0.1161(0.2490) Grad: 1.9118 Epoch: [1][17660/40998] Data 0.553 (0.553) Elapsed 251m 9s (remain 331m 52s) Loss: 0.1031(0.2490) Grad: 1.3117 Epoch: [1][17680/40998] Data 0.552 (0.553) Elapsed 251m 26s (remain 331m 35s) Loss: 0.1873(0.2489) Grad: 2.8792 Epoch: [1][17700/40998] Data 0.553 (0.553) Elapsed 251m 43s (remain 331m 18s) Loss: 0.1755(0.2488) Grad: 3.4949 Epoch: [1][17720/40998] Data 0.553 (0.553) Elapsed 252m 0s (remain 331m 0s) Loss: 0.2286(0.2487) Grad: 2.9451 Epoch: [1][17740/40998] Data 0.553 (0.553) Elapsed 252m 17s (remain 330m 43s) Loss: 0.3279(0.2487) Grad: 2.4857 Epoch: [1][17760/40998] Data 0.553 (0.553) Elapsed 252m 34s (remain 330m 26s) Loss: 0.2772(0.2486) Grad: 3.0324 Epoch: [1][17780/40998] Data 0.552 (0.553) Elapsed 252m 51s (remain 330m 9s) Loss: 0.0954(0.2486) Grad: 1.4501 Epoch: [1][17800/40998] Data 0.553 (0.553) Elapsed 253m 8s (remain 329m 52s) Loss: 0.1740(0.2485) Grad: 1.1246 Epoch: [1][17820/40998] Data 0.552 (0.553) Elapsed 253m 25s (remain 329m 35s) Loss: 0.0844(0.2484) Grad: 1.5910 Epoch: [1][17840/40998] Data 0.553 (0.553) Elapsed 253m 42s (remain 329m 18s) Loss: 0.2233(0.2482) Grad: 2.6648 Epoch: [1][17860/40998] Data 0.552 (0.553) Elapsed 253m 59s (remain 329m 1s) Loss: 0.2347(0.2482) Grad: 4.7091 Epoch: [1][17880/40998] Data 0.553 (0.553) Elapsed 254m 16s (remain 328m 44s) Loss: 0.1699(0.2482) Grad: 1.5545 Epoch: [1][17900/40998] Data 0.552 (0.553) Elapsed 254m 33s (remain 328m 27s) Loss: 0.1253(0.2481) Grad: 1.7767 Epoch: [1][17920/40998] Data 0.553 (0.553) Elapsed 254m 50s (remain 328m 10s) Loss: 0.0736(0.2480) Grad: 1.0490 Epoch: [1][17940/40998] Data 0.553 (0.553) Elapsed 255m 8s (remain 327m 53s) Loss: 0.0950(0.2479) Grad: 0.9803 Epoch: [1][17960/40998] Data 0.553 (0.553) Elapsed 255m 25s (remain 327m 36s) Loss: 0.0361(0.2478) Grad: 0.8263 Epoch: [1][17980/40998] Data 0.553 (0.553) Elapsed 255m 42s (remain 327m 19s) Loss: 0.1219(0.2477) Grad: 1.7528 Epoch: [1][18000/40998] Data 0.553 (0.553) Elapsed 255m 59s (remain 327m 1s) Loss: 0.4527(0.2476) Grad: 3.1047 Epoch: [1][18020/40998] Data 0.553 (0.553) Elapsed 256m 16s (remain 326m 44s) Loss: 0.1700(0.2474) Grad: 2.7845 Epoch: [1][18040/40998] Data 0.553 (0.553) Elapsed 256m 33s (remain 326m 27s) Loss: 0.1530(0.2473) Grad: 4.6904 Epoch: [1][18060/40998] Data 0.552 (0.553) Elapsed 256m 50s (remain 326m 10s) Loss: 0.1233(0.2472) Grad: 1.7688 Epoch: [1][18080/40998] Data 0.553 (0.553) Elapsed 257m 7s (remain 325m 53s) Loss: 0.1040(0.2471) Grad: 1.4221 Epoch: [1][18100/40998] Data 0.552 (0.553) Elapsed 257m 24s (remain 325m 36s) Loss: 0.0451(0.2470) Grad: 0.9076 Epoch: [1][18120/40998] Data 0.553 (0.553) Elapsed 257m 41s (remain 325m 19s) Loss: 0.0446(0.2469) Grad: 0.7721 Epoch: [1][18140/40998] Data 0.553 (0.553) Elapsed 257m 58s (remain 325m 2s) Loss: 0.2053(0.2469) Grad: 2.2058 Epoch: [1][18160/40998] Data 0.553 (0.553) Elapsed 258m 15s (remain 324m 45s) Loss: 0.1498(0.2467) Grad: 1.9591 Epoch: [1][18180/40998] Data 0.553 (0.553) Elapsed 258m 32s (remain 324m 28s) Loss: 0.0827(0.2466) Grad: 2.1156 Epoch: [1][18200/40998] Data 0.554 (0.553) Elapsed 258m 49s (remain 324m 11s) Loss: 0.2338(0.2466) Grad: 3.8750 Epoch: [1][18220/40998] Data 0.553 (0.553) Elapsed 259m 6s (remain 323m 54s) Loss: 0.1064(0.2465) Grad: 1.6336 Epoch: [1][18240/40998] Data 0.553 (0.553) Elapsed 259m 23s (remain 323m 37s) Loss: 0.0426(0.2464) Grad: 0.5867 Epoch: [1][18260/40998] Data 0.552 (0.553) Elapsed 259m 41s (remain 323m 20s) Loss: 0.5421(0.2463) Grad: 7.1349 Epoch: [1][18280/40998] Data 0.552 (0.553) Elapsed 259m 58s (remain 323m 3s) Loss: 0.1671(0.2462) Grad: 4.0544 Epoch: [1][18300/40998] Data 0.553 (0.553) Elapsed 260m 15s (remain 322m 45s) Loss: 0.1459(0.2461) Grad: 1.7394 Epoch: [1][18320/40998] Data 0.553 (0.553) Elapsed 260m 32s (remain 322m 28s) Loss: 0.0829(0.2460) Grad: 1.5254 Epoch: [1][18340/40998] Data 0.553 (0.553) Elapsed 260m 49s (remain 322m 11s) Loss: 0.1042(0.2459) Grad: 2.5611 Epoch: [1][18360/40998] Data 0.553 (0.553) Elapsed 261m 6s (remain 321m 54s) Loss: 0.0661(0.2459) Grad: 1.0508 Epoch: [1][18380/40998] Data 0.552 (0.553) Elapsed 261m 23s (remain 321m 37s) Loss: 0.0505(0.2458) Grad: 0.7083 Epoch: [1][18400/40998] Data 0.552 (0.553) Elapsed 261m 40s (remain 321m 20s) Loss: 0.3217(0.2457) Grad: 2.5998 Epoch: [1][18420/40998] Data 0.553 (0.553) Elapsed 261m 57s (remain 321m 3s) Loss: 0.0896(0.2456) Grad: 1.1268 Epoch: [1][18440/40998] Data 0.553 (0.553) Elapsed 262m 14s (remain 320m 46s) Loss: 0.1998(0.2455) Grad: 2.3304 Epoch: [1][18460/40998] Data 0.553 (0.553) Elapsed 262m 31s (remain 320m 29s) Loss: 0.1079(0.2454) Grad: 2.0039 Epoch: [1][18480/40998] Data 0.553 (0.553) Elapsed 262m 48s (remain 320m 12s) Loss: 0.1348(0.2453) Grad: 2.3925 Epoch: [1][18500/40998] Data 0.552 (0.553) Elapsed 263m 5s (remain 319m 55s) Loss: 0.1331(0.2452) Grad: 2.6048 Epoch: [1][18520/40998] Data 0.552 (0.553) Elapsed 263m 22s (remain 319m 38s) Loss: 0.1928(0.2451) Grad: 1.7789 Epoch: [1][18540/40998] Data 0.553 (0.553) Elapsed 263m 39s (remain 319m 21s) Loss: 0.1155(0.2450) Grad: 2.1506 Epoch: [1][18560/40998] Data 0.553 (0.553) Elapsed 263m 56s (remain 319m 4s) Loss: 0.0387(0.2450) Grad: 0.5673 Epoch: [1][18580/40998] Data 0.553 (0.553) Elapsed 264m 13s (remain 318m 46s) Loss: 0.2733(0.2449) Grad: 2.6358 Epoch: [1][18600/40998] Data 0.553 (0.553) Elapsed 264m 31s (remain 318m 29s) Loss: 0.2935(0.2448) Grad: 3.2840 Epoch: [1][18620/40998] Data 0.553 (0.553) Elapsed 264m 48s (remain 318m 12s) Loss: 0.2186(0.2448) Grad: 4.8716 Epoch: [1][18640/40998] Data 0.553 (0.553) Elapsed 265m 5s (remain 317m 55s) Loss: 0.1864(0.2447) Grad: 2.4338 Epoch: [1][18660/40998] Data 0.553 (0.553) Elapsed 265m 22s (remain 317m 38s) Loss: 0.1841(0.2446) Grad: 2.1224 Epoch: [1][18680/40998] Data 0.552 (0.553) Elapsed 265m 39s (remain 317m 21s) Loss: 0.2263(0.2445) Grad: 3.0600 Epoch: [1][18700/40998] Data 0.553 (0.553) Elapsed 265m 56s (remain 317m 4s) Loss: 0.3420(0.2444) Grad: 2.1759 Epoch: [1][18720/40998] Data 0.552 (0.553) Elapsed 266m 13s (remain 316m 47s) Loss: 0.2274(0.2443) Grad: 4.0982 Epoch: [1][18740/40998] Data 0.553 (0.553) Elapsed 266m 30s (remain 316m 30s) Loss: 0.3606(0.2443) Grad: 4.0770 Epoch: [1][18760/40998] Data 0.553 (0.553) Elapsed 266m 47s (remain 316m 13s) Loss: 0.0340(0.2442) Grad: 0.5353 Epoch: [1][18780/40998] Data 0.553 (0.553) Elapsed 267m 4s (remain 315m 56s) Loss: 0.1919(0.2441) Grad: 3.8375 Epoch: [1][18800/40998] Data 0.553 (0.553) Elapsed 267m 21s (remain 315m 39s) Loss: 0.1648(0.2440) Grad: 2.5278 Epoch: [1][18820/40998] Data 0.552 (0.553) Elapsed 267m 38s (remain 315m 22s) Loss: 0.0451(0.2439) Grad: 0.8171 Epoch: [1][18840/40998] Data 0.552 (0.553) Elapsed 267m 55s (remain 315m 5s) Loss: 0.2507(0.2438) Grad: 2.6086 Epoch: [1][18860/40998] Data 0.553 (0.553) Elapsed 268m 12s (remain 314m 47s) Loss: 0.1341(0.2437) Grad: 1.7936 Epoch: [1][18880/40998] Data 0.553 (0.553) Elapsed 268m 29s (remain 314m 30s) Loss: 0.2315(0.2437) Grad: 4.1346 Epoch: [1][18900/40998] Data 0.552 (0.553) Elapsed 268m 46s (remain 314m 13s) Loss: 0.3452(0.2435) Grad: 2.8581 Epoch: [1][18920/40998] Data 0.553 (0.553) Elapsed 269m 3s (remain 313m 56s) Loss: 0.1273(0.2435) Grad: 1.6399 Epoch: [1][18940/40998] Data 0.553 (0.553) Elapsed 269m 21s (remain 313m 39s) Loss: 0.1423(0.2434) Grad: 1.7945 Epoch: [1][18960/40998] Data 0.552 (0.553) Elapsed 269m 38s (remain 313m 22s) Loss: 0.0461(0.2433) Grad: 0.6477 Epoch: [1][18980/40998] Data 0.553 (0.553) Elapsed 269m 55s (remain 313m 5s) Loss: 0.0917(0.2432) Grad: 1.9324 Epoch: [1][19000/40998] Data 0.553 (0.553) Elapsed 270m 12s (remain 312m 48s) Loss: 0.0565(0.2432) Grad: 1.1468 Epoch: [1][19020/40998] Data 0.553 (0.553) Elapsed 270m 29s (remain 312m 31s) Loss: 0.2122(0.2431) Grad: 5.7390 Epoch: [1][19040/40998] Data 0.553 (0.553) Elapsed 270m 46s (remain 312m 14s) Loss: 0.3473(0.2430) Grad: 3.8021 Epoch: [1][19060/40998] Data 0.553 (0.553) Elapsed 271m 3s (remain 311m 57s) Loss: 0.2280(0.2429) Grad: 3.0202 Epoch: [1][19080/40998] Data 0.553 (0.553) Elapsed 271m 20s (remain 311m 40s) Loss: 0.1162(0.2429) Grad: 1.2837 Epoch: [1][19100/40998] Data 0.553 (0.553) Elapsed 271m 37s (remain 311m 23s) Loss: 0.0909(0.2428) Grad: 1.5304 Epoch: [1][19120/40998] Data 0.552 (0.553) Elapsed 271m 54s (remain 311m 6s) Loss: 0.1120(0.2427) Grad: 1.6885 Epoch: [1][19140/40998] Data 0.553 (0.553) Elapsed 272m 11s (remain 310m 49s) Loss: 0.1221(0.2427) Grad: 2.0557 Epoch: [1][19160/40998] Data 0.553 (0.553) Elapsed 272m 28s (remain 310m 31s) Loss: 0.2757(0.2425) Grad: 2.7100 Epoch: [1][19180/40998] Data 0.553 (0.553) Elapsed 272m 45s (remain 310m 14s) Loss: 0.0795(0.2425) Grad: 1.6499 Epoch: [1][19200/40998] Data 0.552 (0.553) Elapsed 273m 2s (remain 309m 57s) Loss: 0.2080(0.2424) Grad: 1.6410 Epoch: [1][19220/40998] Data 0.553 (0.553) Elapsed 273m 19s (remain 309m 40s) Loss: 0.0771(0.2423) Grad: 0.7612 Epoch: [1][19240/40998] Data 0.553 (0.553) Elapsed 273m 36s (remain 309m 23s) Loss: 0.4706(0.2423) Grad: 4.3592 Epoch: [1][19260/40998] Data 0.553 (0.553) Elapsed 273m 54s (remain 309m 6s) Loss: 0.3268(0.2422) Grad: 2.8371 Epoch: [1][19280/40998] Data 0.553 (0.553) Elapsed 274m 11s (remain 308m 49s) Loss: 0.3448(0.2421) Grad: 4.6229 Epoch: [1][19300/40998] Data 0.552 (0.553) Elapsed 274m 28s (remain 308m 32s) Loss: 0.0629(0.2420) Grad: 0.8713 Epoch: [1][19320/40998] Data 0.553 (0.553) Elapsed 274m 45s (remain 308m 15s) Loss: 0.2597(0.2420) Grad: 3.0607 Epoch: [1][19340/40998] Data 0.553 (0.553) Elapsed 275m 2s (remain 307m 58s) Loss: 0.1673(0.2419) Grad: 2.5576 Epoch: [1][19360/40998] Data 0.552 (0.553) Elapsed 275m 19s (remain 307m 41s) Loss: 0.2655(0.2418) Grad: 2.3812 Epoch: [1][19380/40998] Data 0.553 (0.553) Elapsed 275m 36s (remain 307m 24s) Loss: 0.0996(0.2417) Grad: 2.2970 Epoch: [1][19400/40998] Data 0.553 (0.553) Elapsed 275m 53s (remain 307m 7s) Loss: 0.0725(0.2417) Grad: 0.9698 Epoch: [1][19420/40998] Data 0.553 (0.553) Elapsed 276m 10s (remain 306m 50s) Loss: 0.0410(0.2416) Grad: 0.6932 Epoch: [1][19440/40998] Data 0.553 (0.553) Elapsed 276m 27s (remain 306m 33s) Loss: 0.1619(0.2415) Grad: 1.6838 Epoch: [1][19460/40998] Data 0.553 (0.553) Elapsed 276m 44s (remain 306m 15s) Loss: 0.3124(0.2415) Grad: 3.8552 Epoch: [1][19480/40998] Data 0.553 (0.553) Elapsed 277m 1s (remain 305m 58s) Loss: 0.1986(0.2414) Grad: 2.7909 Epoch: [1][19500/40998] Data 0.553 (0.553) Elapsed 277m 18s (remain 305m 41s) Loss: 0.2110(0.2414) Grad: 2.3025 Epoch: [1][19520/40998] Data 0.553 (0.553) Elapsed 277m 35s (remain 305m 24s) Loss: 0.0698(0.2413) Grad: 0.7284 Epoch: [1][19540/40998] Data 0.552 (0.553) Elapsed 277m 52s (remain 305m 7s) Loss: 0.1746(0.2412) Grad: 2.9130 Epoch: [1][19560/40998] Data 0.553 (0.553) Elapsed 278m 9s (remain 304m 50s) Loss: 0.0882(0.2412) Grad: 1.0842 Epoch: [1][19580/40998] Data 0.553 (0.553) Elapsed 278m 27s (remain 304m 33s) Loss: 0.1096(0.2411) Grad: 1.2281 Epoch: [1][19600/40998] Data 0.553 (0.553) Elapsed 278m 44s (remain 304m 16s) Loss: 0.0565(0.2410) Grad: 1.2948 Epoch: [1][19620/40998] Data 0.552 (0.553) Elapsed 279m 1s (remain 303m 59s) Loss: 0.1108(0.2409) Grad: 1.2629 Epoch: [1][19640/40998] Data 0.552 (0.553) Elapsed 279m 18s (remain 303m 42s) Loss: 0.1209(0.2408) Grad: 1.5887 Epoch: [1][19660/40998] Data 0.553 (0.553) Elapsed 279m 35s (remain 303m 25s) Loss: 0.1309(0.2408) Grad: 2.9003 Epoch: [1][19680/40998] Data 0.552 (0.553) Elapsed 279m 52s (remain 303m 8s) Loss: 0.1696(0.2407) Grad: 1.5461 Epoch: [1][19700/40998] Data 0.552 (0.553) Elapsed 280m 9s (remain 302m 51s) Loss: 0.2936(0.2406) Grad: 4.2503 Epoch: [1][19720/40998] Data 0.553 (0.553) Elapsed 280m 26s (remain 302m 34s) Loss: 0.4948(0.2405) Grad: 3.9454 Epoch: [1][19740/40998] Data 0.553 (0.553) Elapsed 280m 43s (remain 302m 17s) Loss: 0.1811(0.2404) Grad: 2.4617 Epoch: [1][19760/40998] Data 0.553 (0.553) Elapsed 281m 0s (remain 301m 59s) Loss: 0.0719(0.2404) Grad: 0.8475 Epoch: [1][19780/40998] Data 0.552 (0.553) Elapsed 281m 17s (remain 301m 42s) Loss: 0.0738(0.2402) Grad: 0.7520 Epoch: [1][19800/40998] Data 0.553 (0.553) Elapsed 281m 34s (remain 301m 25s) Loss: 0.0722(0.2402) Grad: 1.4441 Epoch: [1][19820/40998] Data 0.553 (0.553) Elapsed 281m 51s (remain 301m 8s) Loss: 0.0865(0.2401) Grad: 1.5266 Epoch: [1][19840/40998] Data 0.553 (0.553) Elapsed 282m 8s (remain 300m 51s) Loss: 0.1173(0.2400) Grad: 1.5981 Epoch: [1][19860/40998] Data 0.553 (0.553) Elapsed 282m 25s (remain 300m 34s) Loss: 0.0407(0.2399) Grad: 0.8835 Epoch: [1][19880/40998] Data 0.553 (0.553) Elapsed 282m 42s (remain 300m 17s) Loss: 0.2849(0.2398) Grad: 2.5711 Epoch: [1][19900/40998] Data 0.553 (0.553) Elapsed 283m 0s (remain 300m 0s) Loss: 0.7922(0.2398) Grad: 15.4472 Epoch: [1][19920/40998] Data 0.553 (0.553) Elapsed 283m 17s (remain 299m 43s) Loss: 0.1902(0.2398) Grad: 1.2798 Epoch: [1][19940/40998] Data 0.553 (0.553) Elapsed 283m 34s (remain 299m 26s) Loss: 0.2658(0.2397) Grad: 2.5439 Epoch: [1][19960/40998] Data 0.553 (0.553) Elapsed 283m 51s (remain 299m 9s) Loss: 0.1158(0.2396) Grad: 2.5047 Epoch: [1][19980/40998] Data 0.552 (0.553) Elapsed 284m 8s (remain 298m 52s) Loss: 0.1399(0.2396) Grad: 1.3208